"""Gradio demo for generating and comparing text and image embeddings."""

import base64
import os
from io import BytesIO
from typing import List, Optional, Union

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image

# Import your handler
from handler import ModelHandler

# Create model handler instance
model_handler = ModelHandler()
model_handler.initialize(None)  # We'll handle device placement manually

# Task choices offered by the dropdowns in the UI below.
TEXT_TASKS = ["retrieval", "text-matching", "code"]
IMAGE_TASKS = ["retrieval"]

# Upper bound on how long a remote image download may block a UI callback.
REQUEST_TIMEOUT_SECONDS = 30


def cosine_similarity(embedding1, embedding2):
    """Calculate cosine similarity between two embeddings.

    Args:
        embedding1: First embedding (any sequence convertible to a NumPy array).
        embedding2: Second embedding, with the same shape as ``embedding1``.

    Returns:
        The cosine similarity, or ``0.0`` when either embedding has zero norm
        (the similarity is undefined there, and dividing would produce NaN).
    """
    vec1 = np.asarray(embedding1, dtype=float)
    vec2 = np.asarray(embedding2, dtype=float)
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denominator == 0:
        return 0.0
    return np.dot(vec1, vec2) / denominator


def process_image(image_input):
    """Convert a supported image input into an RGB PIL image.

    Args:
        image_input: One of
            * a ``PIL.Image.Image`` (what the ``type="pil"`` image components
              in this file pass to their callbacks),
            * an ``http://`` / ``https://`` URL,
            * a ``data:image...`` base64 data URI,
            * a local file path, or
            * any other object accepted by ``PIL.Image.open`` (e.g. a file
              object).

    Returns:
        The image converted to RGB mode.

    Raises:
        requests.RequestException: If a URL download fails, times out, or
            returns an HTTP error status.
        ValueError: If a data URI has no ``,`` separating header and payload.
    """
    # Already-decoded images must not go through Image.open(), which expects
    # a path or a file object.
    if isinstance(image_input, Image.Image):
        return image_input.convert("RGB")

    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):  # URL
            response = requests.get(image_input, timeout=REQUEST_TIMEOUT_SECONDS)
            # Fail on 4xx/5xx instead of trying to parse an error page as an image.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")

        if image_input.startswith("data:image"):  # Base64
            _header, separator, payload = image_input.partition(",")
            if not separator:
                raise ValueError("Malformed data URI: missing ',' before the payload")
            return Image.open(BytesIO(base64.b64decode(payload))).convert("RGB")

        # Local path
        return Image.open(image_input).convert("RGB")

    # Uploaded file (path-like or file object)
    return Image.open(image_input).convert("RGB")


def generate_embeddings(inputs, task="retrieval", input_type="text"):
    """Generate embeddings for text or image inputs.

    Args:
        inputs: List of texts when ``input_type == "text"``; otherwise a list
            of image inputs accepted by ``process_image``.
        task: Task name forwarded to the model's ``forward`` call.
        input_type: ``"text"`` for text inputs; any other value is treated as
            image input.

    Returns:
        A list of embeddings (nested Python lists) on success, ``None`` when
        the model output has no ``"sentence_embedding"`` entry, or a dict
        ``{"error": <message>}`` when any step raises.
    """
    try:
        # Handle different input types
        if input_type == "text":
            features = model_handler.model.tokenize(inputs)
        else:  # image
            processed_images = [process_image(img) for img in inputs]
            features = model_handler.model.tokenize(processed_images)

        # Process features through model
        with torch.no_grad():
            outputs = model_handler.model.forward(features, task=task)

        embeddings = outputs.get("sentence_embedding")
        if embeddings is None:
            return None
        return embeddings.cpu().numpy().tolist()
    except Exception as e:
        # Errors are reported as data; callers decide how to surface them.
        return {"error": str(e)}


def _first_embedding(result):
    """Return the first embedding from a ``generate_embeddings`` result.

    Raises:
        gr.Error: If ``result`` is an error dict or ``None``, so the failure is
            shown in the UI instead of crashing on ``result[0]``.
    """
    if isinstance(result, dict):
        raise gr.Error(result.get("error", "Embedding generation failed"))
    if result is None:
        raise gr.Error("Model did not return a sentence embedding")
    return result[0]


def text_to_embedding(text, task="retrieval"):
    """Convert text to embedding.

    Args:
        text: Input text; ``None``, empty, or whitespace-only text yields ``None``.
        task: Task name forwarded to ``generate_embeddings``.

    Returns:
        The embedding as a list, or ``None`` when there is no text to embed.

    Raises:
        gr.Error: If embedding generation fails.
    """
    if not text or not text.strip():
        return None
    return _first_embedding(generate_embeddings([text], task=task, input_type="text"))


def image_to_embedding(image, task="retrieval"):
    """Convert image to embedding.

    Args:
        image: Image input accepted by ``process_image``; ``None`` yields ``None``.
        task: Task name forwarded to ``generate_embeddings``.

    Returns:
        The embedding as a list, or ``None`` when no image was provided.

    Raises:
        gr.Error: If embedding generation fails.
    """
    if image is None:
        return None
    return _first_embedding(generate_embeddings([image], task=task, input_type="image"))


def compare_embeddings(embedding1, embedding2):
    """Compare two embeddings and return a similarity message.

    Args:
        embedding1: First embedding, or ``None`` if not generated yet.
        embedding2: Second embedding, or ``None`` if not generated yet.

    Returns:
        A human-readable string: the cosine similarity to four decimals, or a
        message explaining why the comparison could not be made.
    """
    if embedding1 is None or embedding2 is None:
        return "Please generate both embeddings first"
    # np.dot would raise on mismatched shapes; report it readably instead.
    if np.shape(embedding1) != np.shape(embedding2):
        return "Embeddings have different shapes and cannot be compared"
    similarity = cosine_similarity(embedding1, embedding2)
    return f"Cosine Similarity: {similarity:.4f}"


# Create Gradio interface
with gr.Blocks(title="Embedding Model Demo") as demo:
    gr.Markdown("# Embedding Model Demo")
    gr.Markdown("Generate and compare embeddings for text and images")

    with gr.Tab("Text Embeddings"):
        with gr.Row():
            with gr.Column():
                text_input1 = gr.Textbox(label="Text Input 1", lines=5)
                task_dropdown1 = gr.Dropdown(
                    TEXT_TASKS,
                    label="Task",
                    value="retrieval",
                )
                text_embed_btn1 = gr.Button("Generate Embedding 1")

            with gr.Column():
                text_input2 = gr.Textbox(label="Text Input 2", lines=5)
                task_dropdown2 = gr.Dropdown(
                    TEXT_TASKS,
                    label="Task",
                    value="retrieval",
                )
                text_embed_btn2 = gr.Button("Generate Embedding 2")

        # Hidden holders that carry the generated embeddings to the compare step.
        embedding_output1 = gr.JSON(label="Embedding 1", visible=False)
        embedding_output2 = gr.JSON(label="Embedding 2", visible=False)

        compare_btn = gr.Button("Compare Embeddings")
        similarity_output = gr.Textbox(label="Similarity Result")

    with gr.Tab("Image Embeddings"):
        with gr.Row():
            with gr.Column():
                image_input1 = gr.Image(label="Image Input 1", type="pil")
                image_task_dropdown1 = gr.Dropdown(
                    IMAGE_TASKS,
                    label="Task",
                    value="retrieval",
                )
                image_embed_btn1 = gr.Button("Generate Embedding 1")

            with gr.Column():
                image_input2 = gr.Image(label="Image Input 2", type="pil")
                image_task_dropdown2 = gr.Dropdown(
                    IMAGE_TASKS,
                    label="Task",
                    value="retrieval",
                )
                image_embed_btn2 = gr.Button("Generate Embedding 2")

        image_embedding_output1 = gr.JSON(label="Embedding 1", visible=False)
        image_embedding_output2 = gr.JSON(label="Embedding 2", visible=False)

        image_compare_btn = gr.Button("Compare Embeddings")
        image_similarity_output = gr.Textbox(label="Similarity Result")

    with gr.Tab("Cross-Modal Comparison"):
        with gr.Row():
            with gr.Column():
                cross_text_input = gr.Textbox(label="Text Input", lines=5)
                cross_text_task = gr.Dropdown(
                    IMAGE_TASKS,
                    label="Task",
                    value="retrieval",
                )
                cross_text_btn = gr.Button("Generate Text Embedding")

            with gr.Column():
                cross_image_input = gr.Image(label="Image Input", type="pil")
                cross_image_task = gr.Dropdown(
                    IMAGE_TASKS,
                    label="Task",
                    value="retrieval",
                )
                cross_image_btn = gr.Button("Generate Image Embedding")

        cross_text_embedding = gr.JSON(label="Text Embedding", visible=False)
        cross_image_embedding = gr.JSON(label="Image Embedding", visible=False)

        cross_compare_btn = gr.Button("Compare Text and Image")
        cross_similarity_output = gr.Textbox(label="Similarity Result")

    # Text tab events
    text_embed_btn1.click(
        fn=text_to_embedding,
        inputs=[text_input1, task_dropdown1],
        outputs=embedding_output1,
    )
    text_embed_btn2.click(
        fn=text_to_embedding,
        inputs=[text_input2, task_dropdown2],
        outputs=embedding_output2,
    )
    compare_btn.click(
        fn=compare_embeddings,
        inputs=[embedding_output1, embedding_output2],
        outputs=similarity_output,
    )

    # Image tab events
    image_embed_btn1.click(
        fn=image_to_embedding,
        inputs=[image_input1, image_task_dropdown1],
        outputs=image_embedding_output1,
    )
    image_embed_btn2.click(
        fn=image_to_embedding,
        inputs=[image_input2, image_task_dropdown2],
        outputs=image_embedding_output2,
    )
    image_compare_btn.click(
        fn=compare_embeddings,
        inputs=[image_embedding_output1, image_embedding_output2],
        outputs=image_similarity_output,
    )

    # Cross-modal tab events
    cross_text_btn.click(
        fn=text_to_embedding,
        inputs=[cross_text_input, cross_text_task],
        outputs=cross_text_embedding,
    )
    cross_image_btn.click(
        fn=image_to_embedding,
        inputs=[cross_image_input, cross_image_task],
        outputs=cross_image_embedding,
    )
    cross_compare_btn.click(
        fn=compare_embeddings,
        inputs=[cross_text_embedding, cross_image_embedding],
        outputs=cross_similarity_output,
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()