kamau1 commited on
Commit
5aa0409
Β·
1 Parent(s): 937c29e

fix: switch to HF default cache, remove MODELS_DIR and unused import, update CTranslate2 download to include all required files

Browse files
Files changed (1) hide show
  1. sema_translation_api.py +22 -17
sema_translation_api.py CHANGED
@@ -13,7 +13,7 @@ from typing import Optional
13
  from fastapi import FastAPI, HTTPException, Request
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel, Field
16
- from huggingface_hub import hf_hub_download, snapshot_download
17
  import ctranslate2
18
  import sentencepiece as spm
19
  import fasttext
@@ -49,7 +49,6 @@ app.add_middleware(
49
 
50
  # --- Global Variables ---
51
  REPO_ID = "sematech/sema-utils"
52
- MODELS_DIR = "hf_models"
53
  beam_size = 1
54
  device = "cpu"
55
 
@@ -71,38 +70,44 @@ def get_nairobi_time():
71
  return full_date, curr_time
72
 
73
  def download_models():
74
- """Download models from HuggingFace Hub"""
75
  print("πŸ”„ Downloading models from sematech/sema-utils...")
76
 
77
- # Ensure models directory exists
78
- os.makedirs(MODELS_DIR, exist_ok=True)
79
-
80
  try:
81
- # Download individual files from root
82
  print("πŸ“₯ Downloading SentencePiece model...")
83
  spm_path = hf_hub_download(
84
  repo_id=REPO_ID,
85
- filename="spm.model",
86
- local_dir=MODELS_DIR
87
  )
88
 
89
  print("πŸ“₯ Downloading language detection model...")
90
  ft_path = hf_hub_download(
91
  repo_id=REPO_ID,
92
- filename="lid218e.bin",
93
- local_dir=MODELS_DIR
94
  )
95
 
96
- # Download translation model (3.3B) from subfolder
97
  print("πŸ“₯ Downloading translation model (3.3B)...")
98
- ct_model_path = snapshot_download(
 
 
 
 
 
 
 
 
 
 
 
 
99
  repo_id=REPO_ID,
100
- allow_patterns="translation_models/sematrans-3.3B/*",
101
- local_dir=MODELS_DIR
102
  )
103
 
104
- # Construct paths
105
- ct_model_full_path = os.path.join(MODELS_DIR, "translation_models", "sematrans-3.3B")
106
 
107
  return spm_path, ft_path, ct_model_full_path
108
 
 
13
  from fastapi import FastAPI, HTTPException, Request
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel, Field
16
+ from huggingface_hub import hf_hub_download
17
  import ctranslate2
18
  import sentencepiece as spm
19
  import fasttext
 
49
 
50
  # --- Global Variables ---
51
  REPO_ID = "sematech/sema-utils"
 
52
  beam_size = 1
53
  device = "cpu"
54
 
 
70
  return full_date, curr_time
71
 
72
  def download_models():
73
+ """Download models from HuggingFace Hub using default cache"""
74
  print("πŸ”„ Downloading models from sematech/sema-utils...")
75
 
 
 
 
76
  try:
77
+ # Download individual files from root (using default HF cache)
78
  print("πŸ“₯ Downloading SentencePiece model...")
79
  spm_path = hf_hub_download(
80
  repo_id=REPO_ID,
81
+ filename="spm.model"
 
82
  )
83
 
84
  print("πŸ“₯ Downloading language detection model...")
85
  ft_path = hf_hub_download(
86
  repo_id=REPO_ID,
87
+ filename="lid218e.bin"
 
88
  )
89
 
90
+ # Download translation model files individually
91
  print("πŸ“₯ Downloading translation model (3.3B)...")
92
+
93
+ # Download all necessary CTranslate2 files
94
+ model_bin_path = hf_hub_download(
95
+ repo_id=REPO_ID,
96
+ filename="translation_models/sematrans-3.3B/model.bin"
97
+ )
98
+
99
+ hf_hub_download(
100
+ repo_id=REPO_ID,
101
+ filename="translation_models/sematrans-3.3B/config.json"
102
+ )
103
+
104
+ hf_hub_download(
105
  repo_id=REPO_ID,
106
+ filename="translation_models/sematrans-3.3B/shared_vocabulary.txt"
 
107
  )
108
 
109
+ # The model directory is the parent of the model.bin file
110
+ ct_model_full_path = os.path.dirname(model_bin_path)
111
 
112
  return spm_path, ft_path, ct_model_full_path
113