from typing import Optional, List, Any from pydantic import Field, PrivateAttr from llama_index.core.llms import CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata from llama_index.core.llms.callbacks import llm_completion_callback from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor from qwen_vl_utils import process_vision_info import torch from typing import Any, List, Optional from llama_index.core.embeddings import BaseEmbedding from sentence_transformers import SentenceTransformer from PIL import Image class QwenVL7BCustomLLM(CustomLLM): model_name: str = Field(default="Qwen/Qwen2.5-VL-7B-Instruct") context_window: int = Field(default=32768) num_output: int = Field(default=256) _model = PrivateAttr() _processor = PrivateAttr() def __init__(self, **kwargs): super().__init__(**kwargs) self._model = Qwen2_5_VLForConditionalGeneration.from_pretrained( self.model_name, torch_dtype=torch.bfloat16, device_map='balanced' ) self._processor = AutoProcessor.from_pretrained(self.model_name) @property def metadata(self) -> LLMMetadata: return LLMMetadata( context_window=self.context_window, num_output=self.num_output, model_name=self.model_name, ) @llm_completion_callback() def complete( self, prompt: str, image_paths: Optional[List[str]] = None, **kwargs: Any ) -> CompletionResponse: # Prepare multimodal input messages = [{"role": "user", "content": []}] if image_paths: for path in image_paths: messages[0]["content"].append({"type": "image", "image": path}) messages[0]["content"].append({"type": "text", "text": prompt}) # Tokenize and process text = self._processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs = process_vision_info(messages) inputs = self._processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) inputs = inputs.to(self._model.device) # Generate output generated_ids = self._model.generate(**inputs, max_new_tokens=self.num_output) generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] output_text = self._processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] return CompletionResponse(text=output_text) @llm_completion_callback() def stream_complete( self, prompt: str, image_paths: Optional[List[str]] = None, **kwargs: Any ) -> CompletionResponseGen: response = self.complete(prompt, image_paths) for token in response.text: yield CompletionResponse(text=token, delta=token) class MultimodalCLIPEmbedding(BaseEmbedding): """ Custom embedding class using CLIP for multimodal capabilities. """ def __init__(self, model_name: str = "clip-ViT-B-32", **kwargs: Any) -> None: super().__init__(**kwargs) self._model = SentenceTransformer(model_name) @classmethod def class_name(cls) -> str: return "multimodal_clip" def _get_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]: if image_path: image = Image.open(image_path) embedding = self._model.encode(image) return embedding.tolist() else: embedding = self._model.encode(query) return embedding.tolist() def _get_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]: if image_path: image = Image.open(image_path) embedding = self._model.encode(image) return embedding.tolist() else: embedding = self._model.encode(text) return embedding.tolist() def _get_text_embeddings(self, texts: List[str], image_paths: Optional[List[str]] = None) -> List[List[float]]: embeddings = [] image_paths = image_paths or [None] * len(texts) for text, img_path in zip(texts, image_paths): if img_path: image = Image.open(img_path) emb = self._model.encode(image) else: emb = self._model.encode(text) embeddings.append(emb.tolist()) return embeddings async def _aget_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]: return self._get_query_embedding(query, image_path) async def _aget_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]: return self._get_text_embedding(text, image_path) # BAAI embedding class # To run on Terminal before running the app, you need to install the FlagEmbedding package. # This can be done by cloning the repository and installing it in editable mode. #!git clone https://github.com/FlagOpen/FlagEmbedding.git #cd FlagEmbedding/research/visual_bge #pip install -e . #go back to the app directory #cd ../../.. class BaaiMultimodalEmbedding(BaseEmbedding): """ Custom embedding class using BAAI's FlagEmbedding for multimodal capabilities. Implements the visual_bge Visualized_BGE model with bge-m3 backend. """ def __init__(self, model_name_bge: str = "BAAI/bge-m3", model_weight: str = "Visualized_m3.pth", device: str = "cuda:1", **kwargs: Any) -> None: super().__init__(**kwargs) # Set device self.device = torch.device(device if torch.cuda.is_available() else "cpu") print(f"BaaiMultimodalEmbedding initializing on device: {self.device}") # Import the visual_bge module from visual_bge.modeling import Visualized_BGE self._model = Visualized_BGE( model_name_bge=model_name_bge, model_weight=model_weight ) self._model.to(self.device) self._model.eval() print(f"Successfully loaded BAAI Visualized_BGE with {model_name_bge}") @classmethod def class_name(cls) -> str: return "baai_multimodal" def _get_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]: """Get embedding for query with optional image""" with torch.no_grad(): if hasattr(self._model, 'encode') and hasattr(self._model, 'preprocess_val'): # Using visual_bge if image_path and query: # Combined text and image query embedding = self._model.encode(image=image_path, text=query) elif image_path: # Image only embedding = self._model.encode(image=image_path) else: # Text only embedding = self._model.encode(text=query) else: # Fallback to sentence-transformers if image_path: from PIL import Image image = Image.open(image_path) embedding = self._model.encode(image) else: embedding = self._model.encode(query) return embedding.cpu().numpy().tolist() if torch.is_tensor(embedding) else embedding.tolist() def _get_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]: """Get embedding for text with optional image""" return self._get_query_embedding(text, image_path) def _get_text_embeddings(self, texts: List[str], image_paths: Optional[List[str]] = None) -> List[List[float]]: """Get embeddings for multiple texts with optional images""" embeddings = [] image_paths = image_paths or [None] * len(texts) for text, img_path in zip(texts, image_paths): emb = self._get_text_embedding(text, img_path) embeddings.append(emb) return embeddings async def _aget_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]: return self._get_query_embedding(query, image_path) async def _aget_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]: return self._get_text_embedding(text, image_path) class PixtralQuantizedLLM(CustomLLM): """ Pixtral 12B quantized model implementation for Kaggle compatibility. Uses float8 quantization for memory efficiency. """ model_name: str = Field(default="mistralai/Pixtral-12B-2409") context_window: int = Field(default=128000) num_output: int = Field(default=512) quantization: str = Field(default="fp8") _model = PrivateAttr() _processor = PrivateAttr() def __init__(self, **kwargs): super().__init__(**kwargs) # Check if we're in a Kaggle environment or have limited resources import psutil available_memory = psutil.virtual_memory().available / (1024**3) # GB if available_memory < 20: # Less than 20GB RAM print(f"Limited memory detected ({available_memory:.1f}GB), using quantized version") self._load_quantized_model() else: print("Sufficient memory available, attempting full model load") try: self._load_full_model() except Exception as e: print(f"Full model loading failed: {e}, falling back to quantized") self._load_quantized_model() def _load_quantized_model(self): """Load quantized Pixtral model for resource-constrained environments""" try: # Try to use a pre-quantized version from HuggingFace quantized_models = [ "RedHatAI/pixtral-12b-FP8-dynamic" ] model_loaded = False for model_id in quantized_models: try: print(f"Attempting to load quantized model: {model_id}") # Standard quantized model loading from transformers import AutoModelForCausalLM, AutoProcessor self._model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float8, device_map="auto", trust_remote_code=True ) self._processor = AutoProcessor.from_pretrained(model_id) print(f"Successfully loaded quantized Pixtral: {model_id}") model_loaded = True break except Exception as e: print(f"Failed to load {model_id}: {e}") continue if not model_loaded: print("All quantized models failed, using CPU-only fallback") self._load_cpu_fallback() except Exception as e: print(f"Quantized loading failed: {e}") self._load_cpu_fallback() def _load_full_model(self): """Load full Pixtral model""" from transformers import AutoModelForCausalLM, AutoProcessor self._model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True ) self._processor = AutoProcessor.from_pretrained(self.model_name) def _load_cpu_fallback(self): """Fallback to CPU-only inference""" try: from transformers import AutoModelForCausalLM, AutoProcessor self._model = AutoModelForCausalLM.from_pretrained( "microsoft/DialoGPT-medium", # Smaller fallback model torch_dtype=torch.float32, device_map="cpu" ) self._processor = AutoProcessor.from_pretrained("microsoft/DialoGPT-medium") print("Using CPU fallback model (DialoGPT-medium)") except Exception as e: print(f"CPU fallback failed: {e}") # Use a minimal implementation self._model = None self._processor = None @property def metadata(self) -> LLMMetadata: return LLMMetadata( context_window=self.context_window, num_output=self.num_output, model_name=f"{self.model_name}-{self.quantization}", ) @llm_completion_callback() def complete( self, prompt: str, image_paths: Optional[List[str]] = None, **kwargs: Any ) -> CompletionResponse: if self._model is None: return CompletionResponse(text="Model not available in current environment") try: # Prepare multimodal input if images provided if image_paths and hasattr(self._processor, 'apply_chat_template'): # Handle multimodal input messages = [{"role": "user", "content": []}] if image_paths: for path in image_paths[:4]: # Limit to 4 images for memory messages[0]["content"].append({"type": "image", "image": path}) messages[0]["content"].append({"type": "text", "text": prompt}) # Process the input inputs = self._processor(messages, return_tensors="pt", padding=True) inputs = {k: v.to(self._model.device) for k, v in inputs.items()} # Generate with torch.no_grad(): outputs = self._model.generate( **inputs, max_new_tokens=min(self.num_output, 256), # Limit for memory do_sample=True, temperature=0.7, pad_token_id=self._processor.tokenizer.eos_token_id ) # Decode response response = self._processor.batch_decode(outputs, skip_special_tokens=True)[0] # Extract only the new generated part if len(messages[0]["content"]) > 0: response = response.split(prompt)[-1].strip() else: # Text-only fallback inputs = self._processor(prompt, return_tensors="pt", padding=True) inputs = {k: v.to(self._model.device) for k, v in inputs.items()} with torch.no_grad(): outputs = self._model.generate( **inputs, max_new_tokens=min(self.num_output, 256), do_sample=True, temperature=0.7, pad_token_id=self._processor.tokenizer.eos_token_id ) response = self._processor.batch_decode(outputs, skip_special_tokens=True)[0] response = response.replace(prompt, "").strip() return CompletionResponse(text=response) except Exception as e: error_msg = f"Generation error: {str(e)}" print(error_msg) return CompletionResponse(text=error_msg) @llm_completion_callback() def stream_complete( self, prompt: str, image_paths: Optional[List[str]] = None, **kwargs: Any ) -> CompletionResponseGen: # For quantized models, streaming might not be efficient # Return the complete response as a single chunk response = self.complete(prompt, image_paths, **kwargs) yield response