Krishna Indukuri committed
Commit 22fcf31 · verified · 1 Parent(s): e031746

Upload 31 files

Files changed (8)
  1. README_CUSTOM.md +87 -0
  2. api.py +94 -0
  3. app.py +226 -0
  4. custom_st.py +194 -151
  5. inference.py +132 -0
  6. model_card.md +85 -0
  7. pipeline.py +98 -0
  8. requirements.txt +14 -0
README_CUSTOM.md ADDED
@@ -0,0 +1,87 @@
+ # Custom Embedding Model
+
+ This repository contains a custom embedding model based on Jina Embeddings V4, optimized for generating embeddings for text, images, and visual documents.
+
+ ## Features
+
+ - Multimodal embeddings for text and images
+ - Multilingual support (30+ languages)
+ - Task-specific adapters (retrieval, text-matching, code)
+ - Flexible embedding dimensions
+
+ ## Setup
+
+ 1. Install the required dependencies:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 2. You can use the model in different ways:
+
+ ### Using the Handler
+
+ ```python
+ from handler import ModelHandler
+
+ # Initialize the model
+ model_handler = ModelHandler()
+ model_handler.initialize(None)
+
+ # Process text inputs
+ text_inputs = ["Your text here", "Another example"]
+ features = model_handler.preprocess({"body": {"inputs": text_inputs}})
+ result = model_handler.inference(features)
+ print(result)  # {"embeddings": [...]}
+ ```
+
+ ### Using the API
+
+ Run the API server:
+
+ ```bash
+ python api.py
+ ```
+
+ Then make API requests:
+
+ ```python
+ import requests
+
+ response = requests.post(
+     "http://localhost:8000/embeddings",
+     json={
+         "inputs": [{"text": "Your text here"}, {"text": "Another example"}],
+         "task": "retrieval"
+     }
+ )
+ print(response.json())  # {"embeddings": [...]}
+ ```
+
+ ### Using the Pipeline
+
+ ```python
+ from pipeline import load_pipeline
+
+ # Load the pipeline
+ pipeline = load_pipeline("path/to/model")
+
+ # Generate embeddings
+ embeddings = pipeline("Your text here", task="retrieval")
+ print(embeddings.shape)  # (1, 2048)
+ ```
+
+ ## Demo UI
+
+ You can also run a Gradio demo UI:
+
+ ```bash
+ python app.py
+ ```
+
+ This will start a web interface for testing embeddings and comparing similarities between text and images.
+
+ ## License
+
+ This model is available under the same terms as the original model it's based on. Please refer to the license information for details.
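One feature listed above, flexible embedding dimensions, is not shown in the snippets. A minimal sketch, assuming the `load_pipeline` helper from `pipeline.py` and a model wrapper whose `forward()` accepts `truncate_dim` (the pipeline simply forwards the argument):

```python
# Sketch only: exercising the flexible-dimension feature via the pipeline.
# "path/to/model" is a placeholder; truncate_dim support depends on the wrapper's forward().
from pipeline import load_pipeline

pipeline = load_pipeline("path/to/model")

full = pipeline("Your text here", task="retrieval")
short = pipeline("Your text here", task="retrieval", truncate_dim=512)
print(full.shape, short.shape)  # e.g. (1, 2048) vs. (1, 512)
```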
api.py ADDED
@@ -0,0 +1,94 @@
+ from fastapi import FastAPI, Request, Response, HTTPException
+ from fastapi.responses import JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ import uvicorn
+ import torch
+ import json
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ import requests
+ from typing import List, Dict, Any, Union, Optional
+ from pydantic import BaseModel, Field
+ import numpy as np
+ import os
+
+ # Import handler
+ from handler import ModelHandler
+
+ app = FastAPI(title="Embedding Model API")
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Initialize model handler
+ model_handler = ModelHandler()
+ model_handler.initialize(None)  # We'll handle context manually
+
+ # Define request/response models
+ class TextInput(BaseModel):
+     text: str = Field(..., description="The text to generate embeddings for")
+
+ class ImageInput(BaseModel):
+     image: str = Field(..., description="URL or base64-encoded image to generate embeddings for")
+
+ class EmbeddingRequest(BaseModel):
+     inputs: List[Union[TextInput, ImageInput]] = Field(..., description="List of text or image inputs")
+     task: str = Field("retrieval", description="Task type: retrieval, text-matching, or code")
+
+ class EmbeddingResponse(BaseModel):
+     embeddings: List[List[float]] = Field(..., description="List of embeddings")
+
+ @app.get("/")
+ async def root():
+     return {"message": "Embedding Model API is running"}
+
+ @app.post("/embeddings", response_model=EmbeddingResponse)
+ async def create_embeddings(request: EmbeddingRequest):
+     try:
+         inputs = []
+
+         # Process inputs
+         for item in request.inputs:
+             if hasattr(item, "text"):
+                 inputs.append(item.text)
+             elif hasattr(item, "image"):
+                 image_data = item.image
+                 if image_data.startswith("http"):
+                     # URL
+                     response = requests.get(image_data)
+                     image = Image.open(BytesIO(response.content)).convert("RGB")
+                 elif image_data.startswith("data:image"):
+                     # Base64
+                     image_b64 = image_data.split(",")[1]
+                     image = Image.open(BytesIO(base64.b64decode(image_b64))).convert("RGB")
+                 else:
+                     raise HTTPException(status_code=400, detail="Invalid image format")
+                 inputs.append(image)
+
+         # Get embeddings
+         features = model_handler.model.tokenize(inputs)
+         outputs = model_handler.model.forward(features, task=request.task)
+         embeddings = outputs.get("sentence_embedding", None)
+
+         if embeddings is None:
+             raise HTTPException(status_code=500, detail="Failed to generate embeddings")
+
+         # Convert to list for JSON serialization
+         embeddings_list = embeddings.cpu().numpy().tolist()
+
+         return {"embeddings": embeddings_list}
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ if __name__ == "__main__":
+     # Run the API server
+     port = int(os.environ.get("PORT", 8000))
+     uvicorn.run(app, host="0.0.0.0", port=port)
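Besides text, the `/embeddings` endpoint in `api.py` also accepts image inputs given as an HTTP(S) URL or a `data:image/...` base64 URI. A hedged request sketch against a locally running server (host, port, and the image URL are placeholders):

```python
# Sketch only: mixing an image input (by URL) and a text input in one request.
# The endpoint and image URL below are placeholders, not fixed values.
import requests

response = requests.post(
    "http://localhost:8000/embeddings",
    json={
        "inputs": [
            {"image": "https://example.com/some-image.jpg"},
            {"text": "A caption to compare against"},
        ],
        "task": "retrieval",
    },
)
print(len(response.json()["embeddings"]))  # one embedding per input
```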
app.py ADDED
@@ -0,0 +1,226 @@
+ import os
+ import torch
+ import gradio as gr
+ import numpy as np
+ from typing import List, Union, Optional
+ from PIL import Image
+ import requests
+ from io import BytesIO
+ import base64
+
+ # Import your handler
+ from handler import ModelHandler
+
+ # Create model handler instance
+ model_handler = ModelHandler()
+ model_handler.initialize(None)  # We'll handle device placement manually
+
+ def cosine_similarity(embedding1, embedding2):
+     """Calculate cosine similarity between two embeddings"""
+     embedding1 = np.array(embedding1)
+     embedding2 = np.array(embedding2)
+     return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+
+ def process_image(image_input):
+     """Process image input (URL, local path, base64 string, or PIL image)"""
+     if isinstance(image_input, str):
+         if image_input.startswith("http"):
+             # URL
+             response = requests.get(image_input)
+             image = Image.open(BytesIO(response.content)).convert("RGB")
+         elif image_input.startswith("data:image"):
+             # Base64
+             image_data = image_input.split(",")[1]
+             image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
+         else:
+             # Local path
+             image = Image.open(image_input).convert("RGB")
+     elif isinstance(image_input, Image.Image):
+         # Already a PIL image (e.g. uploaded via Gradio with type="pil")
+         image = image_input.convert("RGB")
+     else:
+         # File-like object uploaded from Gradio
+         image = Image.open(image_input).convert("RGB")
+     return image
+
+ def generate_embeddings(inputs, task="retrieval", input_type="text"):
+     """Generate embeddings for text or image inputs"""
+     try:
+         # Handle different input types
+         if input_type == "text":
+             features = model_handler.model.tokenize(inputs)
+         else:  # image
+             processed_images = [process_image(img) for img in inputs]
+             features = model_handler.model.tokenize(processed_images)
+
+         # Process features through model
+         with torch.no_grad():
+             outputs = model_handler.model.forward(features, task=task)
+             embeddings = outputs.get("sentence_embedding", None)
+
+         if embeddings is not None:
+             return embeddings.cpu().numpy().tolist()
+         else:
+             return None
+     except Exception as e:
+         return {"error": str(e)}
+
+ def text_to_embedding(text, task="retrieval"):
+     """Convert text to embedding"""
+     if not text.strip():
+         return None
+     return generate_embeddings([text], task=task, input_type="text")[0]
+
+ def image_to_embedding(image, task="retrieval"):
+     """Convert image to embedding"""
+     if image is None:
+         return None
+     return generate_embeddings([image], task=task, input_type="image")[0]
+
+ def compare_embeddings(embedding1, embedding2):
+     """Compare two embeddings and return similarity"""
+     if embedding1 is None or embedding2 is None:
+         return "Please generate both embeddings first"
+     similarity = cosine_similarity(embedding1, embedding2)
+     return f"Cosine Similarity: {similarity:.4f}"
+
+ # Create Gradio interface
+ with gr.Blocks(title="Embedding Model Demo") as demo:
+     gr.Markdown("# Embedding Model Demo")
+     gr.Markdown("Generate and compare embeddings for text and images")
+
+     with gr.Tab("Text Embeddings"):
+         with gr.Row():
+             with gr.Column():
+                 text_input1 = gr.Textbox(label="Text Input 1", lines=5)
+                 task_dropdown1 = gr.Dropdown(
+                     ["retrieval", "text-matching", "code"],
+                     label="Task",
+                     value="retrieval"
+                 )
+                 text_embed_btn1 = gr.Button("Generate Embedding 1")
+
+             with gr.Column():
+                 text_input2 = gr.Textbox(label="Text Input 2", lines=5)
+                 task_dropdown2 = gr.Dropdown(
+                     ["retrieval", "text-matching", "code"],
+                     label="Task",
+                     value="retrieval"
+                 )
+                 text_embed_btn2 = gr.Button("Generate Embedding 2")
+
+         embedding_output1 = gr.JSON(label="Embedding 1", visible=False)
+         embedding_output2 = gr.JSON(label="Embedding 2", visible=False)
+
+         compare_btn = gr.Button("Compare Embeddings")
+         similarity_output = gr.Textbox(label="Similarity Result")
+
+     with gr.Tab("Image Embeddings"):
+         with gr.Row():
+             with gr.Column():
+                 image_input1 = gr.Image(label="Image Input 1", type="pil")
+                 image_task_dropdown1 = gr.Dropdown(
+                     ["retrieval"],
+                     label="Task",
+                     value="retrieval"
+                 )
+                 image_embed_btn1 = gr.Button("Generate Embedding 1")
+
+             with gr.Column():
+                 image_input2 = gr.Image(label="Image Input 2", type="pil")
+                 image_task_dropdown2 = gr.Dropdown(
+                     ["retrieval"],
+                     label="Task",
+                     value="retrieval"
+                 )
+                 image_embed_btn2 = gr.Button("Generate Embedding 2")
+
+         image_embedding_output1 = gr.JSON(label="Embedding 1", visible=False)
+         image_embedding_output2 = gr.JSON(label="Embedding 2", visible=False)
+
+         image_compare_btn = gr.Button("Compare Embeddings")
+         image_similarity_output = gr.Textbox(label="Similarity Result")
+
+     with gr.Tab("Cross-Modal Comparison"):
+         with gr.Row():
+             with gr.Column():
+                 cross_text_input = gr.Textbox(label="Text Input", lines=5)
+                 cross_text_task = gr.Dropdown(
+                     ["retrieval"],
+                     label="Task",
+                     value="retrieval"
+                 )
+                 cross_text_btn = gr.Button("Generate Text Embedding")
+
+             with gr.Column():
+                 cross_image_input = gr.Image(label="Image Input", type="pil")
+                 cross_image_task = gr.Dropdown(
+                     ["retrieval"],
+                     label="Task",
+                     value="retrieval"
+                 )
+                 cross_image_btn = gr.Button("Generate Image Embedding")
+
+         cross_text_embedding = gr.JSON(label="Text Embedding", visible=False)
+         cross_image_embedding = gr.JSON(label="Image Embedding", visible=False)
+
+         cross_compare_btn = gr.Button("Compare Text and Image")
+         cross_similarity_output = gr.Textbox(label="Similarity Result")
+
+     # Text tab events
+     text_embed_btn1.click(
+         fn=text_to_embedding,
+         inputs=[text_input1, task_dropdown1],
+         outputs=embedding_output1
+     )
+
+     text_embed_btn2.click(
+         fn=text_to_embedding,
+         inputs=[text_input2, task_dropdown2],
+         outputs=embedding_output2
+     )
+
+     compare_btn.click(
+         fn=compare_embeddings,
+         inputs=[embedding_output1, embedding_output2],
+         outputs=similarity_output
+     )
+
+     # Image tab events
+     image_embed_btn1.click(
+         fn=image_to_embedding,
+         inputs=[image_input1, image_task_dropdown1],
+         outputs=image_embedding_output1
+     )
+
+     image_embed_btn2.click(
+         fn=image_to_embedding,
+         inputs=[image_input2, image_task_dropdown2],
+         outputs=image_embedding_output2
+     )
+
+     image_compare_btn.click(
+         fn=compare_embeddings,
+         inputs=[image_embedding_output1, image_embedding_output2],
+         outputs=image_similarity_output
+     )
+
+     # Cross-modal tab events
+     cross_text_btn.click(
+         fn=text_to_embedding,
+         inputs=[cross_text_input, cross_text_task],
+         outputs=cross_text_embedding
+     )
+
+     cross_image_btn.click(
+         fn=image_to_embedding,
+         inputs=[cross_image_input, cross_image_task],
+         outputs=cross_image_embedding
+     )
+
+     cross_compare_btn.click(
+         fn=compare_embeddings,
+         inputs=[cross_text_embedding, cross_image_embedding],
+         outputs=cross_similarity_output
+     )
+
+ # Launch the Gradio app
+ if __name__ == "__main__":
+     demo.launch()
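The similarity reported by the demo is plain cosine similarity over the generated embeddings. As a rough sketch, the same helpers can also be called outside the UI; this assumes the model behind `ModelHandler` is available locally, since importing `app` initializes it:

```python
# Sketch only: reusing app.py's helpers outside the Gradio UI.
# Importing app initializes ModelHandler, so the model must be available locally.
from app import text_to_embedding, compare_embeddings

emb_a = text_to_embedding("A photo of a cat", task="retrieval")
emb_b = text_to_embedding("A kitten sitting on a sofa", task="retrieval")
print(compare_embeddings(emb_a, emb_b))  # prints "Cosine Similarity: <value>"
```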
custom_st.py CHANGED
@@ -1,186 +1,229 @@
  import json
  import os
  from io import BytesIO
- from pathlib import Path
- from typing import Any, Dict, List, Literal, Optional, Union
  
- import requests
  import torch
- from PIL import Image
  from torch import nn
- from transformers import AutoConfig, AutoModel, AutoProcessor
  
  
  class Transformer(nn.Module):
  
      save_in_root: bool = True
  
      def __init__(
          self,
-         model_name_or_path: str = "jinaai/jina-embeddings-v4",
-         max_seq_length: Optional[int] = None,
-         config_args: Optional[Dict[str, Any]] = None,
-         model_args: Optional[Dict[str, Any]] = None,
-         tokenizer_args: Optional[Dict[str, Any]] = None,
-         cache_dir: Optional[str] = None,
-         backend: Literal["torch", "onnx", "openvino"] = "torch",
          **kwargs,
      ) -> None:
-         super(Transformer, self).__init__()
-         if backend != "torch":
-             raise ValueError(
-                 f"Backend '{backend}' is not supported, please use 'torch' instead"
              )
-         config_kwargs = config_args or {}
-         model_kwargs = model_args or {}
-         tokenizer_kwargs = tokenizer_args or {}
- 
-         self.config = AutoConfig.from_pretrained(
-             model_name_or_path, cache_dir=cache_dir, trust_remote_code=True, **config_kwargs
-         )
-         self.default_task = model_args.pop("default_task", None)
-         if self.default_task and self.default_task not in self.config.task_names:
              raise ValueError(
-                 f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}."
              )
  
-         self.model = AutoModel.from_pretrained(
-             model_name_or_path, config=self.config, cache_dir=cache_dir, trust_remote_code=True, **model_kwargs
-         )
-         self.processor = AutoProcessor.from_pretrained(
-             model_name_or_path,
              cache_dir=cache_dir,
-             use_fast=True,
-             trust_remote_code=True,
-             **tokenizer_kwargs,
          )
-         self.max_seq_length = max_seq_length or 8192
  
-     def tokenize(
-         self, texts: List[Union[str, Image.Image]], padding: Union[str, bool] = True
      ) -> Dict[str, torch.Tensor]:
-         encoding = {}
-         text_indices = []
-         image_indices = []
-         for i, text in enumerate(texts):
-             if isinstance(text, str):
-                 # Remove Query: or Passage: prefixes when checking for URLs or file paths
-                 clean_text = text
-                 if text.startswith("Query: "):
-                     clean_text = text[len("Query: ") :]
-                 elif text.startswith("Passage: "):
-                     clean_text = text[len("Passage: ") :]
- 
-                 if clean_text.startswith("http"):
-                     response = requests.get(clean_text)
-                     texts[i] = Image.open(BytesIO(response.content)).convert("RGB")
-                     image_indices.append(i)
-                 else:
-                     try:
-                         if Path(clean_text).is_file():
-                             texts[i] = Image.open(clean_text).convert("RGB")
-                             image_indices.append(i)
-                         else:
-                             text_indices.append(i)
-                     except Exception as e:
-                         text_indices.append(i)
-             elif isinstance(text, Image.Image):
-                 image_indices.append(i)
-             else:
-                 raise ValueError(f"Invalid input type: {type(text)}")
-         if text_indices:
-             _texts = [texts[i] for i in text_indices]
-             text_features = self.processor.process_texts(
-                 _texts, max_length=self.max_seq_length
              )
-             for key, value in text_features.items():
-                 encoding[f"text_{key}"] = value
-             encoding["text_indices"] = text_indices
  
-         if image_indices:
-             _images = [texts[i] for i in image_indices]
-             img_features = self.processor.process_images(_images)
-             for key, value in img_features.items():
-                 encoding[f"image_{key}"] = value
-             encoding["image_indices"] = image_indices
  
-         return encoding
  
-     def forward(
          self,
-         features: Dict[str, torch.Tensor],
-         task: Optional[str] = None,
-         truncate_dim: Optional[int] = None,
      ) -> Dict[str, torch.Tensor]:
-         self.model.eval()
- 
-         if task is None:
-             if self.default_task is None:
-                 raise ValueError(
-                     "Task must be specified before encoding data. You can set it either during "
-                     "loading the model (e.g., model_kwargs={'default_task': 'retrieval'}) or "
-                     "pass it as an argument to the encode method (e.g., model.encode(texts, task='retrieval'))."
-                 )
-             task = self.default_task
          else:
-             if task not in self.config.task_names:
-                 raise ValueError(
-                     f"Invalid task: {task}. Must be one of {self.config.task_names}."
-                 )
- 
-         device = self.model.device.type
-         all_embeddings = []
- 
-         with torch.no_grad():
-             if any(k.startswith("text_") for k in features.keys()):
-                 text_batch = {
-                     k[len("text_") :]: v.to(device)
-                     for k, v in features.items()
-                     if k.startswith("text_") and k != "text_indices"
-                 }
-                 text_indices = features.get("text_indices", [])
-                 with torch.autocast(device_type=device, dtype=torch.bfloat16):
-                     text_embeddings = self.model(
-                         **text_batch, task_label=task
-                     ).single_vec_emb
-                 if truncate_dim:
-                     text_embeddings = text_embeddings[:, :truncate_dim]
-                 text_embeddings = torch.nn.functional.normalize(
-                     text_embeddings, p=2, dim=-1
-                 )
-                 for i, embedding in enumerate(text_embeddings):
-                     all_embeddings.append((text_indices[i], embedding))
- 
-             if any(k.startswith("image_") for k in features.keys()):
-                 image_batch = {
-                     k[len("image_") :]: v.to(device)
-                     for k, v in features.items()
-                     if k.startswith("image_") and k != "image_indices"
-                 }
-                 image_indices = features.get("image_indices", [])
- 
-                 with torch.autocast(device_type=device, dtype=torch.bfloat16):
-                     img_embeddings = self.model(
-                         **image_batch, task_label=task
-                     ).single_vec_emb
-                 if truncate_dim:
-                     img_embeddings = img_embeddings[:, :truncate_dim]
-                 img_embeddings = torch.nn.functional.normalize(
-                     img_embeddings, p=2, dim=-1
-                 )
- 
-                 for i, embedding in enumerate(img_embeddings):
-                     all_embeddings.append((image_indices[i], embedding))
- 
-             if not all_embeddings:
-                 raise RuntimeError("No embeddings were generated")
- 
-             all_embeddings.sort(key=lambda x: x[0])  # sort by original index
-             combined_embeddings = torch.stack([emb for _, emb in all_embeddings])
-             features["sentence_embedding"] = combined_embeddings
  
-         return features
  
      @classmethod
      def load(cls, input_path: str) -> "Transformer":
-         return cls(model_name_or_path=input_path)
  import json
+ import logging
  import os
  from io import BytesIO
+ from typing import Any, Dict, List, Optional, Tuple, Union
  
  import torch
  from torch import nn
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
+
+ logger = logging.getLogger(__name__)
  
  
  class Transformer(nn.Module):
+     """Huggingface AutoModel to generate token embeddings.
+     Loads the correct class, e.g. BERT / RoBERTa etc.
+
+     Args:
+         model_name_or_path: Huggingface models name
+             (https://huggingface.co/models)
+         max_seq_length: Truncate any inputs longer than max_seq_length
+         model_args: Keyword arguments passed to the Huggingface
+             Transformers model
+         tokenizer_args: Keyword arguments passed to the Huggingface
+             Transformers tokenizer
+         config_args: Keyword arguments passed to the Huggingface
+             Transformers config
+         cache_dir: Cache dir for Huggingface Transformers to store/load
+             models
+         do_lower_case: If true, lowercases the input (independent if the
+             model is cased or not)
+         tokenizer_name_or_path: Name or path of the tokenizer. When
+             None, then model_name_or_path is used
+     """
  
      save_in_root: bool = True
  
      def __init__(
          self,
+         model_name_or_path: str,
+         max_seq_length: int = None,
+         model_args: Dict[str, Any] = None,
+         tokenizer_args: Dict[str, Any] = None,
+         config_args: Dict[str, Any] = None,
+         cache_dir: str = None,
+         do_lower_case: bool = False,
+         tokenizer_name_or_path: str = None,
          **kwargs,
      ) -> None:
+         super().__init__()
+         self.config_keys = ["max_seq_length", "do_lower_case"]
+         self.do_lower_case = do_lower_case
+         if model_args is None:
+             model_args = {}
+         if tokenizer_args is None:
+             tokenizer_args = {}
+         if config_args is None:
+             config_args = {}
+
+         if kwargs.get("backend", "torch") != "torch":
+             logger.warning(
+                 f'"jinaai/jina-embeddings-v3" is currently not compatible with the {kwargs["backend"]} backend. '
+                 'Continuing with the "torch" backend.'
              )
+
+         self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+
+         self._lora_adaptations = self.config.lora_adaptations
+         if (
+             not isinstance(self._lora_adaptations, list)
+             or len(self._lora_adaptations) < 1
+         ):
              raise ValueError(
+                 f"`lora_adaptations` must be a list and contain at least one element"
              )
+         self._adaptation_map = {
+             name: idx for idx, name in enumerate(self._lora_adaptations)
+         }
  
+         self.default_task = model_args.pop('default_task', None)
+
+         self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
+
+         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
+             tokenizer_args["model_max_length"] = max_seq_length
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
              cache_dir=cache_dir,
+             **tokenizer_args,
          )
  
+         # No max_seq_length set. Try to infer from model
+         if max_seq_length is None:
+             if (
+                 hasattr(self.auto_model, "config")
+                 and hasattr(self.auto_model.config, "max_position_embeddings")
+                 and hasattr(self.tokenizer, "model_max_length")
+             ):
+                 max_seq_length = min(self.auto_model.config.max_position_embeddings, self.tokenizer.model_max_length)
+
+         self.max_seq_length = max_seq_length
+
+         if tokenizer_name_or_path is not None:
+             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
+
+
+     @property
+     def default_task(self):
+         return self._default_task
+
+     @default_task.setter
+     def default_task(self, task: Union[None, str]):
+         self._validate_task(task)
+         self._default_task = task
+
+
+     def _validate_task(self, task: str):
+         if task and task not in self._lora_adaptations:
+             raise ValueError(
+                 f"Unsupported task '{task}'. "
+                 f"Supported tasks are: {', '.join(self.config.lora_adaptations)}. "
+                 f"Alternatively, don't pass the `task` argument to disable LoRA."
+             )
+
+     def forward(
+         self, features: Dict[str, torch.Tensor], task: Optional[str] = None
      ) -> Dict[str, torch.Tensor]:
+         """Returns token_embeddings, cls_token"""
+         self._validate_task(task)
+         task = task or self.default_task
+         adapter_mask = None
+         if task:
+             task_id = self._adaptation_map[task]
+             num_examples = features['input_ids'].size(0)
+             adapter_mask = torch.full(
+                 (num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
              )
  
+         lora_arguments = (
+             {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
+         )
+         features.pop('prompt_length', None)
+         output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
+         output_tokens = output_states[0]
+         features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
+         return features
  
+     def get_word_embedding_dimension(self) -> int:
+         return self.auto_model.config.hidden_size
  
+     def tokenize(
          self,
+         texts: Union[List[str], List[dict], List[Tuple[str, str]]],
+         padding: Union[str, bool] = True
      ) -> Dict[str, torch.Tensor]:
+         """Tokenizes a text and maps tokens to token-ids"""
+         output = {}
+         if isinstance(texts[0], str):
+             to_tokenize = [texts]
+         elif isinstance(texts[0], dict):
+             to_tokenize = []
+             output["text_keys"] = []
+             for lookup in texts:
+                 text_key, text = next(iter(lookup.items()))
+                 to_tokenize.append(text)
+                 output["text_keys"].append(text_key)
+             to_tokenize = [to_tokenize]
          else:
+             batch1, batch2 = [], []
+             for text_tuple in texts:
+                 batch1.append(text_tuple[0])
+                 batch2.append(text_tuple[1])
+             to_tokenize = [batch1, batch2]
+
+         # strip
+         to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]
+
+         # Lowercase
+         if self.do_lower_case:
+             to_tokenize = [[s.lower() for s in col] for col in to_tokenize]
+
+         output.update(
+             self.tokenizer(
+                 *to_tokenize,
+                 padding=padding,
+                 truncation="longest_first",
+                 return_tensors="pt",
+                 max_length=self.max_seq_length,
+             )
+         )
+         return output
+
+     def get_config_dict(self) -> Dict[str, Any]:
+         return {key: self.__dict__[key] for key in self.config_keys}
+
+     def save(self, output_path: str, safe_serialization: bool = True) -> None:
+         self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
+         self.tokenizer.save_pretrained(output_path)
+
+         with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as fOut:
+             json.dump(self.get_config_dict(), fOut, indent=2)
  
  
      @classmethod
      def load(cls, input_path: str) -> "Transformer":
+         # Old classes used other config names than 'sentence_bert_config.json'
+         for config_name in [
+             "sentence_bert_config.json",
+             "sentence_roberta_config.json",
+             "sentence_distilbert_config.json",
+             "sentence_camembert_config.json",
+             "sentence_albert_config.json",
+             "sentence_xlm-roberta_config.json",
+             "sentence_xlnet_config.json",
+         ]:
+             sbert_config_path = os.path.join(input_path, config_name)
+             if os.path.exists(sbert_config_path):
+                 break
+
+         with open(sbert_config_path) as fIn:
+             config = json.load(fIn)
+         # Don't allow configs to set trust_remote_code
+         if "model_args" in config and "trust_remote_code" in config["model_args"]:
+             config["model_args"].pop("trust_remote_code")
+         if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
+             config["tokenizer_args"].pop("trust_remote_code")
+         if "config_args" in config and "trust_remote_code" in config["config_args"]:
+             config["config_args"].pop("trust_remote_code")
+         return cls(model_name_or_path=input_path, **config)
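The core change in the rewritten `forward()` is routing a task name to a per-example LoRA adapter id via `adapter_mask`. A standalone sketch of that routing; the adapter names below are illustrative, the real list comes from `config.lora_adaptations`:

```python
# Sketch of the task-to-adapter routing done in Transformer.forward().
# The adapter names are illustrative; the real list comes from config.lora_adaptations.
import torch

lora_adaptations = ["retrieval", "text-matching", "code"]  # assumed example
adaptation_map = {name: idx for idx, name in enumerate(lora_adaptations)}

task = "text-matching"
num_examples = 4  # batch size, i.e. features["input_ids"].size(0)
adapter_mask = torch.full((num_examples,), adaptation_map[task], dtype=torch.int32)
print(adapter_mask)  # tensor([1, 1, 1, 1], dtype=torch.int32): one adapter id per example
```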
inference.py ADDED
@@ -0,0 +1,132 @@
+ import os
+ import torch
+ import json
+ from typing import Dict, List, Union, Optional, Any
+ from PIL import Image
+ from transformers import AutoConfig, AutoTokenizer
+ from custom_st import Transformer
+
+ class InferenceEmbeddings:
+     def __init__(self, model_path: str):
+         """
+         Initialize the embedding model for inference
+
+         Args:
+             model_path: Path to the model directory
+         """
+         self.model_path = model_path
+         self.model = Transformer(
+             model_name_or_path=model_path,
+             model_args={"default_task": "retrieval", "trust_remote_code": True},
+             trust_remote_code=True
+         )
+         self.model.eval()
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+
+     def encode_text(self,
+                     texts: List[str],
+                     task: str = "retrieval",
+                     prompt_name: Optional[str] = None,
+                     truncate_dim: Optional[int] = None,
+                     return_multivector: bool = False,
+                     max_length: Optional[int] = None,
+                     batch_size: int = 32) -> torch.Tensor:
+         """
+         Encode text inputs to embeddings
+
+         Args:
+             texts: List of text inputs to encode
+             task: Task for which to generate embeddings (retrieval, text-matching, code)
+             prompt_name: Optional prompt type (query, passage)
+             truncate_dim: Optional dimension to truncate embeddings to
+             return_multivector: Whether to return multi-vector embeddings
+             max_length: Maximum token length
+             batch_size: Batch size for encoding
+
+         Returns:
+             Tensor of embeddings
+         """
+         if prompt_name:
+             # Add prompt prefix based on prompt_name
+             if prompt_name == "query":
+                 texts = [f"Query: {text}" for text in texts]
+             elif prompt_name == "passage":
+                 texts = [f"Passage: {text}" for text in texts]
+
+         embeddings = []
+         for i in range(0, len(texts), batch_size):
+             batch_texts = texts[i:i+batch_size]
+             features = self.model.tokenize(batch_texts)
+
+             # Move tensors to device
+             for key, value in features.items():
+                 if isinstance(value, torch.Tensor):
+                     features[key] = value.to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
+                 batch_embeddings = outputs.get("sentence_embedding", None)
+
+             if batch_embeddings is not None:
+                 embeddings.append(batch_embeddings.cpu())
+
+         if embeddings:
+             return torch.cat(embeddings, dim=0)
+         else:
+             raise RuntimeError("Failed to generate embeddings")
+
+     def encode_image(self,
+                      images: List[Union[str, Image.Image]],
+                      task: str = "retrieval",
+                      truncate_dim: Optional[int] = None,
+                      return_multivector: bool = False,
+                      max_pixels: Optional[int] = None,
+                      batch_size: int = 8) -> torch.Tensor:
+         """
+         Encode image inputs to embeddings
+
+         Args:
+             images: List of image inputs (file paths, URLs, or PIL Images)
+             task: Task for which to generate embeddings
+             truncate_dim: Optional dimension to truncate embeddings to
+             return_multivector: Whether to return multi-vector embeddings
+             max_pixels: Maximum number of pixels for image resizing
+             batch_size: Batch size for encoding
+
+         Returns:
+             Tensor of embeddings
+         """
+         embeddings = []
+         for i in range(0, len(images), batch_size):
+             batch_images = images[i:i+batch_size]
+             features = self.model.tokenize(batch_images)
+
+             # Move tensors to device
+             for key, value in features.items():
+                 if isinstance(value, torch.Tensor):
+                     features[key] = value.to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
+                 batch_embeddings = outputs.get("sentence_embedding", None)
+
+             if batch_embeddings is not None:
+                 embeddings.append(batch_embeddings.cpu())
+
+         if embeddings:
+             return torch.cat(embeddings, dim=0)
+         else:
+             raise RuntimeError("Failed to generate embeddings")
+
+ def load_model(model_path: str):
+     """
+     Load the embedding model for inference
+
+     Args:
+         model_path: Path to the model directory
+
+     Returns:
+         Loaded model instance
+     """
+     return InferenceEmbeddings(model_path)
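A short usage sketch for the wrapper above; the model directory is a placeholder, and `encode_text` assumes the underlying `custom_st.Transformer` returns a `sentence_embedding` key:

```python
# Sketch only: using inference.py's loader; "path/to/model" is a placeholder path.
from inference import load_model

model = load_model("path/to/model")
query_embs = model.encode_text(
    ["What is visual document retrieval?"],
    task="retrieval",
    prompt_name="query",
)
print(query_embs.shape)  # (1, embedding_dim)
```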
model_card.md ADDED
@@ -0,0 +1,85 @@
+ ---
+ language: multilingual
+ license: other
+ datasets:
+ - jinaai/jina-vdr
+ pipeline_tag: feature-extraction
+ tags:
+ - embeddings
+ - multilingual-embeddings
+ - multimodal-embeddings
+ - text-to-image
+ - sentence-transformers
+ - sentence-similarity
+ - visual-document-retrieval
+ ---
+
+ # Custom Embedding Model
+
+ This is a custom embedding model based on the Jina Embeddings V4 architecture, specially adapted for embedding tasks involving text, images, and visual documents.
+
+ ## Model Description
+
+ The model supports:
+
+ - **Multimodal Embeddings**: Generate unified embeddings for text and images
+ - **Multilingual Support**: Works across 30+ languages
+ - **Task-specific Modes**: Optimized for retrieval, text-matching, and code tasks
+ - **Flexible Dimensions**: Dense embeddings that can be truncated with minimal performance loss
+
+ ## Usage
+
+ ### Text Embeddings
+
+ ```python
+ from custom_st import Transformer
+
+ # Initialize the model
+ model = Transformer(
+     model_name_or_path="path/to/model",
+     model_args={"default_task": "retrieval", "trust_remote_code": True},
+     trust_remote_code=True
+ )
+
+ # Encode text
+ texts = ["Your text here", "Another text example"]
+ features = model.tokenize(texts)
+ outputs = model.forward(features, task="retrieval")
+ embeddings = outputs["sentence_embedding"]
+ ```
+
+ ### Image Embeddings
+
+ ```python
+ from PIL import Image
+ from custom_st import Transformer
+
+ # Initialize the model
+ model = Transformer(
+     model_name_or_path="path/to/model",
+     model_args={"default_task": "retrieval", "trust_remote_code": True},
+     trust_remote_code=True
+ )
+
+ # Load images
+ images = [Image.open("image1.jpg"), Image.open("image2.jpg")]
+ # Or use URLs
+ image_urls = ["http://example.com/image1.jpg", "http://example.com/image2.jpg"]
+
+ # Encode images
+ features = model.tokenize(images)  # or model.tokenize(image_urls)
+ outputs = model.forward(features, task="retrieval")
+ embeddings = outputs["sentence_embedding"]
+ ```
+
+ ## Requirements
+
+ - Python 3.8+
+ - PyTorch 2.0+
+ - Transformers 4.30+
+ - PEFT 0.4+
+ - Pillow 9.0+
+
+ ## License
+
+ This model is available under the same terms as the original model it's based on. Please refer to the license information in the repository for details.
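Since text and image embeddings share one space, they can be compared directly. A hedged continuation of the two snippets above, reusing the same `model` and the example image file from the card:

```python
# Sketch only: cosine similarity between a text and an image embedding, continuing
# from the model card's snippets (both return a "sentence_embedding" tensor).
import torch.nn.functional as F

text_features = model.tokenize(["A photo of a cat"])
image_features = model.tokenize([Image.open("image1.jpg")])

text_emb = model.forward(text_features, task="retrieval")["sentence_embedding"]
image_emb = model.forward(image_features, task="retrieval")["sentence_embedding"]

similarity = F.cosine_similarity(text_emb, image_emb)
print(similarity.item())
```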
pipeline.py ADDED
@@ -0,0 +1,98 @@
+ from typing import List, Union, Dict, Any, Optional
+ import torch
+ import numpy as np
+ from PIL import Image
+ from transformers import Pipeline
+ from custom_st import Transformer
+
+ class EmbeddingPipeline(Pipeline):
+     """
+     Pipeline for generating embeddings using custom transformer model
+     """
+
+     def __init__(self, model, **kwargs):
+         super().__init__(model=model, **kwargs)
+
+         # Default task if not specified
+         self.default_task = "retrieval"
+
+     def _sanitize_parameters(self, task=None, truncate_dim=None, **kwargs):
+         preprocess_params = {}
+         forward_params = {}
+         postprocess_params = {}
+
+         if task is not None:
+             forward_params["task"] = task
+
+         if truncate_dim is not None:
+             forward_params["truncate_dim"] = truncate_dim
+
+         return preprocess_params, forward_params, postprocess_params
+
+     def preprocess(self, inputs, **preprocess_params):
+         """
+         Preprocess the inputs before passing to model
+         """
+         # Handle single input vs list of inputs
+         if not isinstance(inputs, list):
+             inputs = [inputs]
+
+         # Tokenize/prepare the inputs
+         features = self.model.tokenize(inputs)
+         return features
+
+     def _forward(self, features, task=None, truncate_dim=None):
+         """
+         Forward pass through the model
+         """
+         # Set default task if not provided
+         if task is None:
+             task = self.default_task
+
+         # Forward pass
+         outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
+         return outputs
+
+     def postprocess(self, model_outputs, **postprocess_params):
+         """
+         Convert model outputs to final embeddings
+         """
+         # Extract embeddings
+         embeddings = model_outputs.get("sentence_embedding", None)
+
+         if embeddings is None:
+             raise ValueError("No embeddings were generated")
+
+         # Convert to numpy
+         embeddings = embeddings.cpu().numpy()
+
+         return embeddings
+
+ def load_pipeline(model_path: str, device: str = None):
+     """
+     Load the embedding pipeline from a model path
+
+     Args:
+         model_path: Path to the model directory
+         device: Device to use for inference (cpu, cuda, etc.)
+
+     Returns:
+         EmbeddingPipeline instance
+     """
+     # Determine device
+     if device is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Load model
+     model = Transformer(
+         model_name_or_path=model_path,
+         model_args={"default_task": "retrieval", "trust_remote_code": True},
+         trust_remote_code=True
+     )
+     model.to(device)
+     model.eval()
+
+     # Create pipeline
+     pipeline = EmbeddingPipeline(model=model, device=device)
+
+     return pipeline
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch>=2.0.0
+ transformers>=4.30.2
+ peft>=0.4.0
+ pillow>=9.0.0
+ numpy>=1.22.0
+ sentencepiece>=0.1.97
+ protobuf>=3.20.0
+ accelerate>=0.20.0
+ gradio>=3.50.0
+ requests>=2.28.0
+ torchvision>=0.15.0
+ fastapi>=0.95.0
+ uvicorn>=0.22.0
+ pydantic>=1.10.0