davanstrien (HF Staff) committed
Commit c97aadf · 1 Parent(s): 595f871

refactor: update model and embedding configurations, enhance logging for database setup

Files changed (1)
  1. main.py +110 -17
main.py CHANGED
@@ -24,8 +24,8 @@ load_dotenv(override=True)
 HF_TOKEN = os.getenv("HF_TOKEN")
 login(token=HF_TOKEN)
 # Configuration constants
-MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
-EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
+MODEL_NAME = "davanstrien/Smol-Hub-tldr"
+EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
 BATCH_SIZE = 2000
 CACHE_TTL = "24h"
 TRENDING_CACHE_TTL = "1h"  # 1 hour cache for trending data
@@ -38,9 +38,7 @@ else:
     DEVICE = "cpu"


-tokenizer = AutoTokenizer.from_pretrained(
-    "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
-)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER
 # Set up logging
@@ -90,7 +88,7 @@ app.add_middleware(
 def get_embedding_function():
     logger.info(f"Using device: {DEVICE}")
     return embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name="nomic-ai/modernbert-embed-base", device=DEVICE
+        model_name="Qwen/Qwen3-Embedding-0.6B", device=DEVICE
     )


@@ -135,24 +133,64 @@ def setup_database():
         logger.info(f"Most recent record in DB from: {latest_update}")
         logger.info(f"Oldest record in DB from: {min(last_modifieds)}")

+        # Log sample of existing timestamps for debugging
+        sample_timestamps = sorted(last_modifieds, reverse=True)[:5]
+        logger.info(f"Sample of most recent DB timestamps: {sample_timestamps}")
+
     # Filter and process only newer records
     df = df.select(["datasetId", "summary", "likes", "downloads", "last_modified"])

-    # Log some stats about the incoming data
-    sample_dates = df.select("last_modified").limit(5).collect()
-    logger.info(f"Sample of incoming dates: {sample_dates}")
-
+    # Log some stats about the incoming data BEFORE collecting
     total_incoming = df.select(pl.len()).collect().item()
-    logger.info(f"Total incoming records: {total_incoming}")
+    logger.info(f"Total incoming records from source: {total_incoming}")
+
+    # Get sample of dates to understand the data
+    sample_df = (
+        df.select(["datasetId", "last_modified"])
+        .sort("last_modified", descending=True)
+        .limit(10)
+        .collect()
+    )
+    logger.info("Sample of most recent incoming records:")
+    for row in sample_df.iter_rows():
+        logger.info(f"  {row[0]}: {row[1]}")

     if latest_update:
         logger.info(f"Filtering records newer than {latest_update}")
+        logger.info(f"Latest update type: {type(latest_update)}")
+
+        # Get date range before filtering
+        date_stats = df.select(
+            [
+                pl.col("last_modified").min().alias("min_date"),
+                pl.col("last_modified").max().alias("max_date"),
+            ]
+        ).collect()
+        logger.info(f"Incoming data date range: {date_stats.row(0)}")
+
         # Ensure last_modified is datetime before comparison
         df = df.with_columns(pl.col("last_modified").str.to_datetime())
         df = df.filter(pl.col("last_modified") > latest_update)
         filtered_count = df.select(pl.len()).collect().item()
         logger.info(f"Found {filtered_count} records to update after filtering")

+        if filtered_count == 0:
+            logger.warning(
+                "No new records found after filtering! This might indicate a problem."
+            )
+            # Log a few records that were just below the cutoff
+            just_before = (
+                df.select(["datasetId", "last_modified"])
+                .filter(pl.col("last_modified") <= latest_update)
+                .sort("last_modified", descending=True)
+                .limit(5)
+                .collect()
+            )
+            if len(just_before) > 0:
+                logger.info("Records just before cutoff:")
+                for row in just_before.iter_rows():
+                    logger.info(f"  {row[0]}: {row[1]}")
+
     df = df.collect()
     total_rows = len(df)

@@ -170,8 +208,26 @@ def setup_database():
             f"({batch_df['last_modified'].min()} to {batch_df['last_modified'].max()})"
         )

+        ids_to_upsert = batch_df.select(["datasetId"]).to_series().to_list()
+
+        # Log first few IDs being upserted
+        logger.info(f"Upserting IDs (first 5): {ids_to_upsert[:5]}")
+
+        # Check if any of these already exist
+        existing_check = dataset_collection.get(
+            ids=ids_to_upsert[:5], include=["metadatas"]
+        )
+        if existing_check["ids"]:
+            logger.info(
+                f"Found {len(existing_check['ids'])} existing records in this batch sample"
+            )
+            for idx, id_ in enumerate(existing_check["ids"]):
+                logger.info(
+                    f"  Existing: {id_} - last_modified: {existing_check['metadatas'][idx].get('last_modified')}"
+                )
+
         dataset_collection.upsert(
-            ids=batch_df.select(["datasetId"]).to_series().to_list(),
+            ids=ids_to_upsert,
             documents=batch_df.select(["summary"]).to_series().to_list(),
             metadatas=[
                 {
@@ -188,18 +244,55 @@ def setup_database():
         )
         logger.info(f"Processed {i + batch_size:,} / {total_rows:,} records")

-    logger.info(
-        f"Database initialized with {dataset_collection.count():,} total rows"
-    )
+    # Final validation
+    final_count = dataset_collection.count()
+    logger.info(f"Database initialized with {final_count:,} total rows")
+
+    # Verify the update worked by checking latest records
+    if final_count > 0:
+        final_metadata = dataset_collection.get(include=["metadatas"], limit=5)
+        final_timestamps = [
+            dateutil.parser.parse(m.get("last_modified"))
+            for m in final_metadata.get("metadatas")
+        ]
+        if final_timestamps:
+            latest_after_update = max(final_timestamps)
+            logger.info(f"Latest record after update: {latest_after_update}")
+            if latest_update and latest_after_update <= latest_update:
+                logger.error(
+                    "WARNING: No new records were added! Latest timestamp hasn't changed."
+                )
+            elif latest_update:
+                logger.info(
+                    f"Successfully added records from {latest_update} to {latest_after_update}"
+                )

     # Load model data
     model_lazy_df = pl.scan_parquet(
         "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
     )
     model_row_count = model_lazy_df.select(pl.len()).collect().item()
-    logger.info(f"Row count of new model data: {model_row_count}")
+    logger.info(f"Total model records in source: {model_row_count}")
+
+    # Get the most recent last_modified date from the model collection
+    model_latest_update = None
+    if model_collection.count() > 0:
+        model_metadata = model_collection.get(include=["metadatas"]).get(
+            "metadatas"
+        )
+        logger.info(
+            f"Found {len(model_metadata)} existing model records in collection"
+        )
+
+        model_last_modifieds = [
+            dateutil.parser.parse(m.get("last_modified")) for m in model_metadata
+        ]
+        model_latest_update = max(model_last_modifieds)
+        logger.info(f"Most recent model record in DB from: {model_latest_update}")

-    if model_collection.count() < model_row_count:
+    # Always process models to handle updates (not just new additions)
+    should_update_models = True
+    if model_latest_update:
         schema = model_lazy_df.collect_schema()
         select_columns = [
             "modelId",