Spaces:
Runtime error
Runtime error
File size: 10,399 Bytes
a22e84b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from sentence_transformers import SentenceTransformer
from .multimodal_dispatcher import ImageEmbedder, TextEmbedder, TRANSFORMERS_AVAILABLE
import time
import os
import socket
from pathlib import Path
class TopicAwareRetriever:
    """Retrieve video chunks from a Qdrant collection for a free-text query.

    The retriever opens the file-based Qdrant storage at
    ``qdrant_storage_path`` (falling back to a project connection singleton
    and then to ``localhost:6333``), embeds the query, and searches the
    ``video_chunks`` collection.

    Attributes:
        qdrant_storage_path: Filesystem path of the local Qdrant storage.
        client: The currently owned Qdrant client, or ``None`` when every
            connection attempt failed.
        embedder: ``ImageEmbedder`` instance when ``TRANSFORMERS_AVAILABLE``
            is true, otherwise ``None``.
        model: ``SentenceTransformer`` fallback when ``TRANSFORMERS_AVAILABLE``
            is false, otherwise ``None``.
    """

    # Name of the Qdrant collection that is searched.
    COLLECTION_NAME = "video_chunks"
    # Length every query vector is padded/truncated to before searching.
    EMBEDDING_DIM = 512
    # Payload fields a hit must carry to be turned into a clip.
    REQUIRED_PAYLOAD_KEYS = ("video_id", "start", "end")

    def __init__(self, qdrant_storage_path="/Users/yufeizhen/Desktop/project/qdrant_storage"):
        """Connect to Qdrant and pick the query embedder.

        Args:
            qdrant_storage_path: Directory of the file-based Qdrant storage
                (the original comment states it is shared with
                ``video_ingester.py``).
        """
        self.qdrant_storage_path = qdrant_storage_path

        # Create the storage directory itself (not just its parent). This is
        # best effort: if the path cannot be created, the fallback
        # connections in _connect_to_qdrant() must still get their chance.
        try:
            os.makedirs(self.qdrant_storage_path, exist_ok=True)
        except OSError as e:
            print("Could not create Qdrant storage directory {}: {}".format(
                self.qdrant_storage_path, e))

        self.client = None
        self._connect_to_qdrant()

        if TRANSFORMERS_AVAILABLE:
            self.embedder = ImageEmbedder()
            self.model = None  # Not needed with image embedder
        else:
            # Fallback to text embedder with dimension padding
            self.model = SentenceTransformer("all-MiniLM-L6-v2")
            self.embedder = None

    @staticmethod
    def _close_client(client):
        """Close ``client`` if it exposes ``close()``; never raise."""
        if client is None:
            return
        try:
            if hasattr(client, "close"):
                client.close()
        except Exception as e:
            # Closing is best effort (may not work with all client versions).
            print("Ignoring error while closing Qdrant client: {}".format(e))

    def _connect_to_qdrant(self):
        """Establish a Qdrant connection, trying three sources in order.

        1. File-based storage at ``self.qdrant_storage_path``.
        2. The ``connection`` singleton from
           ``llm_engineering.infrastructure.db.qdrant``.
        3. A server on ``localhost:6333``.

        Returns:
            ``True`` when ``self.client`` was set by one of the attempts,
            ``False`` (with ``self.client = None``) when all of them failed.
        """
        # Attempt 1: direct connection to the storage path.
        try:
            # NOTE(review): kept from the original; nothing in this file
            # shows that this environment variable is actually read.
            os.environ["QDRANT_CLIENT_TIMEOUT"] = "30"
            self.client = QdrantClient(
                path=self.qdrant_storage_path,
                timeout=30,
            )
            print("Connected to Qdrant storage at: {}".format(self.qdrant_storage_path))

            # Verify the connection with a simple operation.
            collections = self.client.get_collections()
            print("Available collections: {}".format(collections))

            if self.client.collection_exists(self.COLLECTION_NAME):
                count = self.client.count(self.COLLECTION_NAME)
                print("Found video_chunks collection with {} points".format(count.count))
            else:
                print("WARNING: video_chunks collection does not exist - have you ingested videos?")
            return True
        except Exception as e:
            print("Error connecting to local Qdrant storage: {}".format(e))

        # Attempt 2: project-wide connection singleton.
        try:
            from llm_engineering.infrastructure.db.qdrant import connection
            self.client = connection
            print("Using fallback Qdrant connection singleton")
            return True
        except Exception as e2:
            print("Fallback connection also failed: {}".format(e2))

        # Attempt 3: last resort, a Qdrant server on localhost.
        try:
            self.client = QdrantClient(
                host="localhost",
                port=6333,
                timeout=30,
            )
            print("Trying localhost connection")
            self.client.get_collections()  # Test the connection
            print("Connected to Qdrant via localhost")
            return True
        except Exception as e3:
            print("All connection attempts failed: {}".format(e3))
            self.client = None
            return False

    def _create_fresh_connection(self):
        """Replace ``self.client`` with a newly opened storage-path client.

        The current client is closed first so that only one client holds the
        local storage at a time. On success the new client is stored on
        ``self.client`` (so the instance never keeps pointing at a closed
        client) and returned.

        Returns:
            The new client, or ``None`` if it could not be created; in that
            case ``self.client`` is left as it was.
        """
        try:
            self._close_client(getattr(self, "client", None))
            print("Creating fresh connection to Qdrant...")
            new_client = QdrantClient(
                path=self.qdrant_storage_path,
                timeout=30,
            )
        except Exception as e:
            print("Failed to create fresh connection: {}".format(e))
            return None

        self.client = new_client
        return new_client

    def _embed_query(self, query):
        """Return the query vector, or ``None`` when no embedder is usable."""
        if TRANSFORMERS_AVAILABLE and self.embedder:
            try:
                print("Encoding query with CLIP: '{}'".format(
                    query[:50] + "..." if len(query) > 50 else query))
                query_embedding = self.embedder.encode_text(query)
                print("Query embedded successfully")
                return query_embedding
            except Exception as e:
                print("Error during query embedding with CLIP: {}".format(e))
                if self.model:
                    print("Falling back to sentence transformer model")
                    return self._encode_with_sentence_transformer(query)
                print("No fallback available, returning empty results")
                return None

        # Use sentence-transformers, padded/truncated for compatibility.
        return self._encode_with_sentence_transformer(query)

    def retrieve(self, query: str, k: int = 3):
        """Search the ``video_chunks`` collection for ``query``.

        Args:
            query: Free-text query to embed and search for.
            k: Maximum number of hits requested from Qdrant.

        Returns:
            A list of clip dicts as built by ``_process_results``. An empty
            list is returned when no connection or embedding is available,
            the collection is missing, the retries are exhausted, or any
            unexpected error occurs.
        """
        if self.client is None:
            print("No Qdrant connection available. Attempting to reconnect...")
            if not self._connect_to_qdrant():
                print("Failed to establish Qdrant connection")
                return []

        query_embedding = self._embed_query(query)
        if query_embedding is None:
            return []

        max_retries = 5
        for attempt in range(1, max_retries + 1):
            try:
                print("Sending search request to Qdrant (attempt {}/{})".format(
                    attempt, max_retries))

                # A fresh connection per search is used to avoid connection
                # reset issues; it also serves as the reconnect on retries.
                search_client = self._create_fresh_connection()
                if search_client is None:
                    print("Using existing client...")
                    search_client = self.client

                if not search_client.collection_exists(self.COLLECTION_NAME):
                    print("ERROR: video_chunks collection doesn't exist in Qdrant")
                    return []

                results = search_client.search(
                    collection_name=self.COLLECTION_NAME,
                    query_vector=query_embedding,
                    limit=k,
                    with_payload=["start", "end", "video_id", "topics", "text"],
                    timeout=10,  # shorter timeout so a search cannot hang
                )
                print("Search successful, found {} results".format(len(results)))
                return self._process_results(results)

            except (UnexpectedResponse, ConnectionError, socket.error) as e:
                print("Qdrant search error (attempt {}/{}): {}".format(
                    attempt, max_retries, e))
                if attempt >= max_retries:
                    print("All retry attempts failed, returning empty results")
                    return []

                # Exponential backoff: 2, 4, 8, 16 seconds (no wait after the
                # final attempt). The next iteration reopens the connection.
                sleep_time = 2 ** attempt
                print("Waiting {} seconds before retrying...".format(sleep_time))
                time.sleep(sleep_time)

            except Exception as other_error:
                print("Unexpected error during search: {}".format(other_error))
                return []  # Return empty results on any other error

        return []

    def _encode_with_sentence_transformer(self, query):
        """Encode ``query`` with ``self.model`` as a fixed-length float list.

        The embedding is zero-padded or truncated to ``EMBEDDING_DIM``.

        Returns:
            A list of ``EMBEDDING_DIM`` floats. If encoding fails for any
            reason (including ``self.model`` being ``None``) a zero vector is
            returned as a last resort.
        """
        dim = self.EMBEDDING_DIM
        try:
            print("Using sentence-transformer for query embedding")
            embed = self.model.encode(query)
            size = len(embed)
            if size < dim:
                print("Padding embedding from {} to {} dimensions".format(size, dim))
                return embed.tolist() + [0.0] * (dim - size)
            if size > dim:
                print("Truncating embedding from {} to {} dimensions".format(size, dim))
                return embed[:dim].tolist()
            return embed.tolist()
        except Exception as e:
            print("Error encoding with sentence transformer: {}".format(e))
            # Return a zero vector as last resort
            return [0.0] * dim

    def _process_results(self, results):
        """Convert Qdrant hits into plain clip dictionaries.

        Args:
            results: Iterable of hits exposing ``payload`` and ``score``;
                may be empty or ``None``.

        Returns:
            A list of dicts with keys ``video_id``, ``start``, ``end``,
            ``score``, ``text`` (default ``""``) and ``topics`` (default
            ``[]``). Hits whose payload is missing or lacks one of
            ``REQUIRED_PAYLOAD_KEYS`` are skipped so that a single malformed
            point does not discard the whole result set.
        """
        if not results:
            return []

        clips = []
        for hit in results:
            payload = hit.payload or {}
            missing = [key for key in self.REQUIRED_PAYLOAD_KEYS if key not in payload]
            if missing:
                print("Skipping hit with incomplete payload (missing: {})".format(
                    ", ".join(missing)))
                continue
            clips.append({
                "video_id": payload["video_id"],
                "start": payload["start"],
                "end": payload["end"],
                "score": hit.score,
                "text": payload.get("text", ""),  # text content for debugging
                "topics": payload.get("topics", []),
            })
        return clips
|