# Captured from a Hugging Face Spaces page; the Space's status banner read "Runtime error".
import os
import socket
import time
from pathlib import Path

from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from sentence_transformers import SentenceTransformer

from .multimodal_dispatcher import ImageEmbedder, TextEmbedder, TRANSFORMERS_AVAILABLE
class TopicAwareRetriever:
    """Retrieve video chunks from the ``video_chunks`` Qdrant collection for a text query.

    Queries are embedded with the CLIP text encoder (``ImageEmbedder.encode_text``)
    when ``TRANSFORMERS_AVAILABLE`` is true at construction time. Otherwise they
    use a SentenceTransformer, whose output is padded or truncated to
    ``VECTOR_SIZE``. Qdrant is reached through the first backend that works:
    local file storage, then the ``llm_engineering`` connection singleton, then
    a server on ``localhost:6333``.
    """

    # Dimension that query vectors are forced to. Taken from the original
    # "pad to 512 dimensions for compatibility" logic; presumably matches the
    # collection created by video_ingester.py -- confirm there.
    VECTOR_SIZE = 512
    # Payload fields requested from Qdrant for every hit.
    PAYLOAD_FIELDS = ("start", "end", "video_id", "topics", "text")

    def __init__(self, qdrant_storage_path="/Users/yufeizhen/Desktop/project/qdrant_storage"):
        """Connect to Qdrant and set up the query embedder.

        Args:
            qdrant_storage_path: Folder for Qdrant local file storage. It is the
                same path that video_ingester.py writes to.
        """
        # NOTE(review): the default is a developer-machine path. Other
        # deployments (e.g. a hosted Space) should pass their own path.
        self.qdrant_storage_path = qdrant_storage_path
        self._ensure_parent_dir(qdrant_storage_path)
        self.client = None
        # True only while self.client is a local-storage client opened by this
        # object. Such clients are re-opened for every search attempt.
        self._uses_local_storage = False
        self._connect_to_qdrant()
        # Use the appropriate embedder based on availability.
        if TRANSFORMERS_AVAILABLE:
            self.embedder = ImageEmbedder()
            self.model = None  # Not needed with the image embedder
        else:
            # Fall back to a text embedder; its output is padded to VECTOR_SIZE.
            self.model = SentenceTransformer("all-MiniLM-L6-v2")
            self.embedder = None

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of ``path``; report failures instead of raising.

        A bare relative name has no directory component, so there is nothing
        to create. ``os.makedirs("")`` would raise in that case. Creation
        failures such as permission errors are only printed. The connection
        fallbacks (singleton, localhost) can then still be tried.
        """
        parent = os.path.dirname(path)
        if not parent:
            return
        try:
            os.makedirs(parent, exist_ok=True)
        except OSError as e:
            print("WARNING: could not create storage directory {}: {}".format(parent, e))

    @staticmethod
    def _close_quietly(client):
        """Call ``client.close()`` if it exists; print errors instead of raising."""
        if client is None or not hasattr(client, "close"):
            return
        try:
            client.close()
        except Exception as e:
            print("Ignoring error while closing Qdrant client: {}".format(e))

    def _connect_to_qdrant(self):
        """Establish ``self.client``, trying each backend in turn.

        The order is: local file storage at ``self.qdrant_storage_path``, then
        the ``llm_engineering`` connection singleton, then ``localhost:6333``.

        Returns:
            True if a client was established. False otherwise, in which case
            ``self.client`` is None.
        """
        self._uses_local_storage = False
        if self._connect_local_storage():
            self._uses_local_storage = True
            return True
        if self._connect_singleton():
            return True
        if self._connect_localhost():
            return True
        self.client = None
        return False

    def _connect_local_storage(self):
        """Open and verify a local-storage client, storing it in ``self.client``.

        Returns True on success. On failure the partially opened client is
        closed and ``self.client`` is left untouched.
        """
        client = None
        try:
            # Kept from the original implementation. It is unverified whether
            # qdrant_client actually reads this environment variable.
            os.environ["QDRANT_CLIENT_TIMEOUT"] = "30"
            client = QdrantClient(path=self.qdrant_storage_path, timeout=30)
            print("Connected to Qdrant storage at: {}".format(self.qdrant_storage_path))
            # Verify the connection with a simple operation.
            collections = client.get_collections()
            print("Available collections: {}".format(collections))
            if client.collection_exists("video_chunks"):
                count = client.count("video_chunks")
                print("Found video_chunks collection with {} points".format(count.count))
            else:
                print("WARNING: video_chunks collection does not exist - have you ingested videos?")
        except Exception as e:
            print("Error connecting to local Qdrant storage: {}".format(e))
            # Don't leave a half-initialised client open while a fallback replaces it.
            self._close_quietly(client)
            return False
        self.client = client
        return True

    def _connect_singleton(self):
        """Use the shared ``llm_engineering`` Qdrant connection; True on success."""
        try:
            from llm_engineering.infrastructure.db.qdrant import connection
        except Exception as e:
            print("Fallback connection also failed: {}".format(e))
            return False
        self.client = connection
        print("Using fallback Qdrant connection singleton")
        return True

    def _connect_localhost(self):
        """Connect to a Qdrant server on localhost:6333 and verify it; True on success."""
        client = None
        try:
            client = QdrantClient(host="localhost", port=6333, timeout=30)
            print("Trying localhost connection")
            client.get_collections()  # Test the connection
            print("Connected to Qdrant via localhost")
        except Exception as e:
            print("All connection attempts failed: {}".format(e))
            self._close_quietly(client)
            return False
        self.client = client
        return True

    def _create_fresh_connection(self):
        """Replace ``self.client`` with a newly opened local-storage client.

        The current client is closed first. NOTE(review): qdrant_client's local
        mode appears to lock the storage folder, so two clients on one path
        cannot be open at once. Confirm this for the installed version.

        Returns:
            The new client, which is also stored in ``self.client``, or None
            if it could not be created. On failure ``self.client`` is None,
            because the previous client has already been closed and must not
            be reused.
        """
        old_client = getattr(self, "client", None)
        self.client = None
        self._close_quietly(old_client)
        print("Creating fresh connection to Qdrant...")
        try:
            self.client = QdrantClient(path=self.qdrant_storage_path, timeout=30)
        except Exception as e:
            print("Failed to create fresh connection: {}".format(e))
            return None
        return self.client

    def _client_for_search(self):
        """Return the client to use for one search attempt.

        A local-storage client opened by this object is replaced with a fresh
        one on every attempt. Shared or remote clients are reused as they are.
        If no client is available, the full connection fallback chain is run
        again.

        Raises:
            ConnectionError: if no client could be established. The caller
                treats this as a retryable error.
        """
        if getattr(self, "_uses_local_storage", False):
            if self._create_fresh_connection() is None:
                print("Fresh connection failed; retrying connection fallbacks...")
        if self.client is None and not self._connect_to_qdrant():
            raise ConnectionError("No Qdrant client available")
        return self.client

    def retrieve(self, query: str, k: int = 3):
        """Return up to ``k`` clips that best match ``query``.

        Args:
            query: Free-text query.
            k: Maximum number of results, passed to Qdrant as ``limit``.

        Returns:
            A list of dicts with keys ``video_id``, ``start``, ``end``,
            ``score``, ``text`` and ``topics``. The list is empty on any
            failure; errors are printed, not raised.
        """
        if self.client is None:
            print("No Qdrant connection available. Attempting to reconnect...")
            if not self._connect_to_qdrant():
                print("Failed to establish Qdrant connection")
                return []
        query_embedding = self._embed_query(query)
        if query_embedding is None:
            return []
        return self._search_with_retries(query_embedding, k)

    def _embed_query(self, query):
        """Embed ``query``; return None when no usable embedding was produced.

        Uses the CLIP text encoder when ``self.embedder`` is set. ``__init__``
        only sets it when ``TRANSFORMERS_AVAILABLE`` is true. If CLIP fails,
        the SentenceTransformer path is used when ``self.model`` exists.
        """
        if self.embedder is None:
            # Use sentence-transformers, padded to VECTOR_SIZE for compatibility.
            return self._encode_with_sentence_transformer(query)
        try:
            preview = query[:50] + "..." if len(query) > 50 else query
            print("Encoding query with CLIP: '{}'".format(preview))
            query_embedding = self.embedder.encode_text(query)
        except Exception as e:
            print("Error during query embedding with CLIP: {}".format(e))
            if self.model is not None:
                print("Falling back to sentence transformer model")
                return self._encode_with_sentence_transformer(query)
            print("No fallback available, returning empty results")
            return None
        print("Query embedded successfully")
        return query_embedding

    def _search_with_retries(self, query_embedding, k, max_retries=5):
        """Search ``video_chunks`` and retry transient errors with exponential backoff.

        Qdrant/network errors are retried up to ``max_retries`` attempts in
        total. Any other error, or a missing collection, returns ``[]`` at once.
        """
        for attempt in range(1, max_retries + 1):
            try:
                print("Sending search request to Qdrant (attempt {}/{})".format(
                    attempt, max_retries))
                client = self._client_for_search()
                if not client.collection_exists("video_chunks"):
                    print("ERROR: video_chunks collection doesn't exist in Qdrant")
                    return []
                results = client.search(
                    collection_name="video_chunks",
                    query_vector=query_embedding,
                    limit=k,
                    with_payload=list(self.PAYLOAD_FIELDS),
                    timeout=10,  # shorter timeout just for this search
                )
                print("Search successful, found {} results".format(len(results)))
                return self._process_results(results)
            except (UnexpectedResponse, ConnectionError, socket.error) as e:
                print("Qdrant search error (attempt {}/{}): {}".format(
                    attempt, max_retries, e))
                if attempt >= max_retries:
                    print("All retry attempts failed, returning empty results")
                    return []
                # Exponential backoff between attempts. With 5 attempts the
                # waits are 2, 4, 8 and 16 seconds. In local-storage mode, the
                # next attempt opens a brand-new client via _client_for_search.
                sleep_time = 2 ** attempt
                print("Waiting {} seconds before retrying...".format(sleep_time))
                time.sleep(sleep_time)
            except Exception as other_error:
                print("Unexpected error during search: {}".format(other_error))
                return []
        return []

    def _encode_with_sentence_transformer(self, query):
        """Embed ``query`` with ``self.model`` and fit it to ``VECTOR_SIZE``.

        Shorter embeddings are zero-padded and longer ones truncated. On any
        error a zero vector is returned as a last resort.
        NOTE(review): a zero vector may score poorly or be rejected under
        cosine distance; confirm this is acceptable.
        """
        size = self.VECTOR_SIZE
        try:
            print("Using sentence-transformer for query embedding")
            embed = self.model.encode(query)
            if len(embed) < size:
                print("Padding embedding from {} to {} dimensions".format(len(embed), size))
                return embed.tolist() + [0.0] * (size - len(embed))
            if len(embed) > size:
                print("Truncating embedding from {} to {} dimensions".format(len(embed), size))
                return embed[:size].tolist()
            return embed.tolist()
        except Exception as e:
            print("Error encoding with sentence transformer: {}".format(e))
            return [0.0] * size

    def _process_results(self, results):
        """Convert Qdrant hits into clip dicts.

        A hit whose payload is missing ``video_id``, ``start`` or ``end`` is
        skipped with a message. One malformed point no longer discards the
        whole result set.
        """
        clips = []
        for hit in results or []:
            payload = hit.payload or {}
            try:
                clip = {
                    "video_id": payload["video_id"],
                    "start": payload["start"],
                    "end": payload["end"],
                    "score": hit.score,
                    "text": payload.get("text", ""),  # text content for debugging
                    "topics": payload.get("topics", []),
                }
            except KeyError as missing:
                print("Skipping hit {} with incomplete payload (missing {})".format(
                    getattr(hit, "id", "?"), missing))
                continue
            clips.append(clip)
        return clips