# purpleriann's picture
# Upload folder using huggingface_hub
# a22e84b verified
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from sentence_transformers import SentenceTransformer
from .multimodal_dispatcher import ImageEmbedder, TextEmbedder, TRANSFORMERS_AVAILABLE
import time
import os
import socket
from pathlib import Path
class TopicAwareRetriever:
    """Retrieve video clips from a Qdrant collection for a text query.

    The retriever embeds the query (with ``self.embedder`` when
    ``TRANSFORMERS_AVAILABLE`` is true, otherwise with a SentenceTransformer
    whose output is padded/truncated to ``EMBEDDING_DIM``) and searches the
    ``COLLECTION_NAME`` collection.

    Connection handling: ``self.client`` always holds the client that will be
    used next, and ``self._client_kind`` records where it came from:

    * ``"local"``  - file-based storage at ``qdrant_storage_path`` (owned here)
    * ``"shared"`` - the imported ``connection`` singleton (never closed here)
    * ``"remote"`` - a ``localhost:6333`` client (owned here)
    """

    # Collection searched by ``retrieve``.
    COLLECTION_NAME = "video_chunks"
    # Length every query vector is padded/truncated to before searching.
    EMBEDDING_DIM = 512
    # Timeout (seconds) passed to every QdrantClient constructor.
    CLIENT_TIMEOUT = 30
    # Timeout (seconds) passed to each individual search call.
    SEARCH_TIMEOUT = 10
    # Maximum number of search attempts per ``retrieve`` call.
    MAX_RETRIES = 5
    # Payload fields a hit must carry to be turned into a clip.
    REQUIRED_PAYLOAD_KEYS = ("video_id", "start", "end")

    def __init__(self, qdrant_storage_path="/Users/yufeizhen/Desktop/project/qdrant_storage"):
        """Connect to Qdrant and set up the query embedder.

        Args:
            qdrant_storage_path: Directory of the file-based Qdrant storage
                (per the original comment, the same path video_ingester.py
                writes to).
        """
        self.qdrant_storage_path = qdrant_storage_path

        # Make sure the parent directory exists. A bare relative path such as
        # "qdrant_storage" has an empty dirname, and os.makedirs("") raises,
        # so only create the parent when there is one.
        parent_dir = os.path.dirname(self.qdrant_storage_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        self.client = None
        self._client_kind = None
        self._connect_to_qdrant()

        if TRANSFORMERS_AVAILABLE:
            self.embedder = ImageEmbedder()
            self.model = None  # Not needed with image embedder
        else:
            # Fallback to text embedder with dimension padding
            self.model = SentenceTransformer("all-MiniLM-L6-v2")
            self.embedder = None

    @staticmethod
    def _close_client(client):
        """Best-effort close of *client*; failures are reported, not raised."""
        close = getattr(client, "close", None)
        if close is None:
            # Some client objects may not expose close(); nothing to do.
            return
        try:
            close()
        except Exception as e:
            print("Error closing Qdrant client: {}".format(e))

    def _drop_client(self):
        """Forget the current client, closing it unless it is the shared singleton."""
        client = getattr(self, "client", None)
        kind = getattr(self, "_client_kind", None)
        self.client = None
        self._client_kind = None
        if client is not None and kind != "shared":
            self._close_client(client)

    def _connect_to_qdrant(self):
        """Establish a Qdrant connection, trying three sources in order.

        1. File-based storage at ``self.qdrant_storage_path``.
        2. The ``connection`` singleton from ``llm_engineering``.
        3. A server on ``localhost:6333``.

        Returns:
            True when ``self.client`` was set to a working client, False when
            every attempt failed (``self.client`` is then None).
        """
        # Release whatever we held before so a stale local client cannot
        # block the new one from opening the same storage folder.
        self._drop_client()

        # Attempt 1: direct connection to the storage path.
        try:
            # Kept from the original code; NOTE(review): it is not visible
            # here which component reads this variable - confirm it is used.
            os.environ["QDRANT_CLIENT_TIMEOUT"] = "30"
            self.client = QdrantClient(
                path=self.qdrant_storage_path,
                timeout=self.CLIENT_TIMEOUT,
            )
            self._client_kind = "local"
            print("Connected to Qdrant storage at: {}".format(self.qdrant_storage_path))

            # Verify the connection with a simple operation.
            collections = self.client.get_collections()
            print("Available collections: {}".format(collections))

            if self.client.collection_exists(self.COLLECTION_NAME):
                count = self.client.count(self.COLLECTION_NAME)
                print("Found {} collection with {} points".format(
                    self.COLLECTION_NAME, count.count))
            else:
                print("WARNING: {} collection does not exist - have you ingested videos?".format(
                    self.COLLECTION_NAME))
            return True
        except Exception as e:
            print("Error connecting to local Qdrant storage: {}".format(e))
            # The client may have been created before verification failed;
            # close it so it does not linger while we use a fallback.
            self._drop_client()

        # Attempt 2: the project-wide connection singleton.
        try:
            from llm_engineering.infrastructure.db.qdrant import connection
            self.client = connection
            self._client_kind = "shared"
            print("Using fallback Qdrant connection singleton")
            return True
        except Exception as e2:
            print("Fallback connection also failed: {}".format(e2))
            self._drop_client()

        # Attempt 3: last resort, a server on localhost.
        try:
            self.client = QdrantClient(
                host="localhost",
                port=6333,
                timeout=self.CLIENT_TIMEOUT,
            )
            self._client_kind = "remote"
            print("Trying localhost connection")
            self.client.get_collections()  # Test the connection
            print("Connected to Qdrant via localhost")
            return True
        except Exception as e3:
            print("All connection attempts failed: {}".format(e3))
            self._drop_client()
            return False

    def _create_fresh_connection(self):
        """Open a new file-based client and make it the current client.

        A local client that this instance already holds is closed first:
        qdrant-client's local mode is understood to lock the storage folder
        to a single client, so the old one has to go before a new one can
        open the same path. A shared or remote client is left usable until
        the fresh client has actually been created, so the caller still has
        a working fallback when creation fails.

        Returns:
            The new client (also stored in ``self.client``), or None when it
            could not be created.
        """
        kind = getattr(self, "_client_kind", None)
        superseded = None
        if kind in ("shared", "remote"):
            superseded = getattr(self, "client", None)
        else:
            self._drop_client()

        print("Creating fresh connection to Qdrant...")
        try:
            fresh = QdrantClient(
                path=self.qdrant_storage_path,
                timeout=self.CLIENT_TIMEOUT,
            )
        except Exception as e:
            print("Failed to create fresh connection: {}".format(e))
            return None

        # Only close a client we own; the shared singleton belongs to others.
        if kind == "remote" and superseded is not None:
            self._close_client(superseded)
        self.client = fresh
        self._client_kind = "local"
        return fresh

    def _client_for_search(self):
        """Return the client to use for one search attempt, or None.

        Prefers a fresh file-based client, then the still-open existing
        client, then a full reconnect through ``_connect_to_qdrant``.
        """
        client = self._create_fresh_connection()
        if client is not None:
            return client
        if self.client is not None:
            print("Using existing client...")
            return self.client
        self._connect_to_qdrant()
        return self.client

    def _embed_query(self, query):
        """Turn *query* into a search vector, or return None when that fails.

        Uses ``self.embedder.encode_text`` when ``TRANSFORMERS_AVAILABLE`` is
        true and an embedder is set (the log messages call this CLIP);
        otherwise, or when that raises and ``self.model`` is set, uses the
        sentence-transformer path.
        """
        if TRANSFORMERS_AVAILABLE and self.embedder:
            try:
                preview = query[:50] + "..." if len(query) > 50 else query
                print("Encoding query with CLIP: '{}'".format(preview))
                embedding = self.embedder.encode_text(query)
                print("Query embedded successfully")
                return embedding
            except Exception as e:
                print("Error during query embedding with CLIP: {}".format(e))
                if not self.model:
                    print("No fallback available, returning empty results")
                    return None
                print("Falling back to sentence transformer model")
        # Sentence-transformers output is padded to EMBEDDING_DIM for
        # compatibility with the stored vectors.
        return self._encode_with_sentence_transformer(query)

    def retrieve(self, query: str, k: int = 3):
        """Search the collection for the *k* clips closest to *query*.

        Args:
            query: Free-text search query.
            k: Maximum number of clips to return.

        Returns:
            A list of clip dicts with keys ``video_id``, ``start``, ``end``,
            ``score``, ``text`` and ``topics``. The list is empty when no
            connection is available, the query cannot be embedded, the
            collection does not exist, or the search fails.
        """
        if self.client is None:
            print("No Qdrant connection available. Attempting to reconnect...")
            if not self._connect_to_qdrant():
                print("Failed to establish Qdrant connection")
                return []

        query_embedding = self._embed_query(query)
        if query_embedding is None:
            # Searching with a made-up vector would return arbitrary clips.
            return []

        max_retries = self.MAX_RETRIES
        for attempt in range(1, max_retries + 1):
            try:
                print("Sending search request to Qdrant (attempt {}/{})".format(
                    attempt, max_retries))

                # A fresh connection per search avoids connection-reset issues.
                client = self._client_for_search()
                if client is None:
                    # Handled by the retry branch below.
                    raise ConnectionError("No Qdrant client available")

                if not client.collection_exists(self.COLLECTION_NAME):
                    print("ERROR: {} collection doesn't exist in Qdrant".format(
                        self.COLLECTION_NAME))
                    return []

                # Shorter timeout for the search itself to avoid hanging.
                results = client.search(
                    collection_name=self.COLLECTION_NAME,
                    query_vector=query_embedding,
                    limit=k,
                    with_payload=["start", "end", "video_id", "topics", "text"],
                    timeout=self.SEARCH_TIMEOUT,
                )
                print("Search successful, found {} results".format(len(results)))
                return self._process_results(results)

            except (UnexpectedResponse, ConnectionError, socket.error) as e:
                print("Qdrant search error (attempt {}/{}): {}".format(
                    attempt, max_retries, e))
                if attempt >= max_retries:
                    print("All retry attempts failed, returning empty results")
                    return []

                # Exponential backoff: 2, 4, 8, 16 seconds with 5 attempts
                # (no sleep after the final failure). The next iteration
                # obtains a new client via _client_for_search().
                sleep_time = 2 ** attempt
                print("Waiting {} seconds before retrying...".format(sleep_time))
                time.sleep(sleep_time)

            except Exception as other_error:
                print("Unexpected error during search: {}".format(other_error))
                return []  # Return empty results on any other error

        return []

    def _encode_with_sentence_transformer(self, query):
        """Embed *query* with ``self.model`` and fit it to ``EMBEDDING_DIM``.

        Shorter embeddings are right-padded with zeros, longer ones are
        truncated.

        Returns:
            A list of floats of length ``EMBEDDING_DIM``, or None when
            encoding fails (including when ``self.model`` is None).
        """
        target_dim = self.EMBEDDING_DIM
        try:
            print("Using sentence-transformer for query embedding")
            embed = self.model.encode(query)
            size = len(embed)
            if size < target_dim:
                print("Padding embedding from {} to {} dimensions".format(size, target_dim))
                return embed.tolist() + [0.0] * (target_dim - size)
            if size > target_dim:
                print("Truncating embedding from {} to {} dimensions".format(size, target_dim))
                return embed[:target_dim].tolist()
            return embed.tolist()
        except Exception as e:
            print("Error encoding with sentence transformer: {}".format(e))
            # An all-zero vector carries no information about the query, so
            # signal failure instead of searching with it.
            return None

    def _process_results(self, results):
        """Convert search hits into plain clip dictionaries.

        Hits whose payload is missing or lacks any of
        ``REQUIRED_PAYLOAD_KEYS`` are skipped, so one malformed point does
        not discard the remaining results.

        Args:
            results: Iterable of hits exposing ``payload`` and ``score``;
                may be empty or None.

        Returns:
            A list of dicts with keys ``video_id``, ``start``, ``end``,
            ``score``, ``text`` (default ``""``) and ``topics`` (default
            ``[]``).
        """
        if not results:
            return []

        clips = []
        for hit in results:
            payload = hit.payload or {}
            missing = [key for key in self.REQUIRED_PAYLOAD_KEYS if key not in payload]
            if missing:
                print("Skipping hit with incomplete payload (missing: {})".format(
                    ", ".join(missing)))
                continue
            clips.append({
                "video_id": payload["video_id"],
                "start": payload["start"],
                "end": payload["end"],
                "score": hit.score,
                "text": payload.get("text", ""),  # Text content for debugging
                "topics": payload.get("topics", []),
            })
        return clips