"""
Sema Translation API - New Implementation

Created for testing the consolidated sema-utils repository.
Uses the HuggingFace Hub for model downloading.
"""

import os
import time
from datetime import datetime
from typing import Optional

import ctranslate2
import fasttext
import pytz
import sentencepiece as spm
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field


class TranslationRequest(BaseModel):
    text: str = Field(..., example="Habari ya asubuhi", description="Text to translate")
    target_language: str = Field(..., example="eng_Latn", description="FLORES-200 target language code")
    source_language: Optional[str] = Field(
        None, example="swh_Latn", description="Optional FLORES-200 source language code"
    )


class TranslationResponse(BaseModel):
    translated_text: str
    source_language: str
    target_language: str
    inference_time: float
    timestamp: str

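
# Example request/response pair for the /translate endpoint (illustrative values;
# the actual translation and timing depend on the loaded model):
#   request:  {"text": "Habari ya asubuhi", "target_language": "eng_Latn"}
#   response: {"translated_text": "Good morning", "source_language": "swh_Latn",
#              "target_language": "eng_Latn", "inference_time": 0.42,
#              "timestamp": "Friday | 2024-01-05 | 09:15:00"}
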

app = FastAPI(
    title="Sema Translation API",
    description="Translation API using consolidated sema-utils models from HuggingFace",
    version="2.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Model configuration
REPO_ID = "sematech/sema-utils"
beam_size = 1
device = "cpu"

# Globals populated once at startup and shared across requests
lang_model = None
sp_model = None
translator = None


def get_nairobi_time():
    """Return ('Day | YYYY-MM-DD | HH:MM:SS', 'HH:MM:SS') for the current Nairobi time."""
    nairobi_timezone = pytz.timezone('Africa/Nairobi')
    current_time_nairobi = datetime.now(nairobi_timezone)

    curr_day = current_time_nairobi.strftime('%A')
    curr_date = current_time_nairobi.strftime('%Y-%m-%d')
    curr_time = current_time_nairobi.strftime('%H:%M:%S')

    full_date = f"{curr_day} | {curr_date} | {curr_time}"
    return full_date, curr_time


def get_model_paths():
    """Get model paths from the HuggingFace cache (models pre-downloaded in Docker)."""
    print("Loading models from cache...")

    try:
        offline_mode = os.environ.get("HF_HUB_OFFLINE", "0") == "1"

        if offline_mode:
            print("Running in offline mode - using cached models")

            spm_path = hf_hub_download(
                repo_id=REPO_ID,
                filename="spm.model",
                local_files_only=True
            )

            ft_path = hf_hub_download(
                repo_id=REPO_ID,
                filename="lid218e.bin",
                local_files_only=True
            )

            model_bin_path = hf_hub_download(
                repo_id=REPO_ID,
                filename="translation_models/sematrans-3.3B/model.bin",
                local_files_only=True
            )

            # CTranslate2 loads a model directory, not the model.bin file itself
            ct_model_full_path = os.path.dirname(model_bin_path)

        else:
            print("Running in online mode - downloading models")

            spm_path = hf_hub_download(
                repo_id=REPO_ID,
                filename="spm.model"
            )

            ft_path = hf_hub_download(
                repo_id=REPO_ID,
                filename="lid218e.bin"
            )

            model_bin_path = hf_hub_download(
                repo_id=REPO_ID,
                filename="translation_models/sematrans-3.3B/model.bin"
            )

            # Fetch the companion files CTranslate2 expects alongside model.bin
            hf_hub_download(
                repo_id=REPO_ID,
                filename="translation_models/sematrans-3.3B/config.json"
            )

            hf_hub_download(
                repo_id=REPO_ID,
                filename="translation_models/sematrans-3.3B/shared_vocabulary.txt"
            )

            ct_model_full_path = os.path.dirname(model_bin_path)

        print("Model paths:")
        print(f"  SentencePiece: {spm_path}")
        print(f"  Language detection: {ft_path}")
        print(f"  Translation model: {ct_model_full_path}")

        return spm_path, ft_path, ct_model_full_path

    except Exception as e:
        print(f"Error loading models: {e}")
        raise

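
# Note: hf_hub_download returns paths inside the local HuggingFace cache, typically
# ~/.cache/huggingface/hub/models--sematech--sema-utils/snapshots/<revision>/...
# (the exact location depends on HF_HOME / HF_HUB_CACHE), so all files of one
# snapshot resolve into the same directory tree.
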

def load_models():
    """Load all models into memory."""
    global lang_model, sp_model, translator

    print("Loading models into memory...")

    spm_path, ft_path, ct_model_path = get_model_paths()

    # Silence fastText's harmless load_model deprecation warning
    fasttext.FastText.eprint = lambda x: None

    try:
        print("1. Loading language detection model...")
        lang_model = fasttext.load_model(ft_path)

        print("2. Loading SentencePiece model...")
        sp_model = spm.SentencePieceProcessor()
        sp_model.load(spm_path)

        print("3. Loading translation model...")
        translator = ctranslate2.Translator(ct_model_path, device)

        print("All models loaded successfully!")

    except Exception as e:
        print(f"Error loading models: {e}")
        raise


def translate_with_detection(text: str, target_lang: str):
    """Translate text with automatic source language detection."""
    start_time = time.time()

    source_sents = [text.strip()]
    target_prefix = [[target_lang]]

    # fastText returns labels like "__label__swh_Latn"; strip the prefix
    predictions = lang_model.predict(text.replace('\n', ' '), k=1)
    source_lang = predictions[0][0].replace('__label__', '')

    # Subword the input and frame it as [<src_lang>, tokens..., </s>]
    source_sents_subworded = sp_model.encode(source_sents, out_type=str)
    source_sents_subworded = [[source_lang] + sent + ["</s>"] for sent in source_sents_subworded]

    translations = translator.translate_batch(
        source_sents_subworded,
        batch_type="tokens",
        max_batch_size=2048,
        beam_size=beam_size,
        target_prefix=target_prefix,
    )

    # Decode, then drop the leading target-language token (and surrounding whitespace)
    translations = [translation[0]['tokens'] for translation in translations]
    translations_desubword = sp_model.decode(translations)
    translated_text = translations_desubword[0][len(target_lang):].strip()

    inference_time = time.time() - start_time

    return source_lang, translated_text, inference_time

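
# Illustration of the encoding pipeline above (the subword pieces shown are made
# up; the real pieces depend on the SentencePiece model):
#   text:      "Habari ya asubuhi"
#   detected:  "swh_Latn"   (from the fastText label "__label__swh_Latn")
#   encoded:   [["swh_Latn", "▁Habari", "▁ya", "▁asubuhi", "</s>"]]
#   prefix:    [["eng_Latn"]]   (steers the decoder toward the target language)
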

def translate_with_source(text: str, source_lang: str, target_lang: str):
    """Translate text with a caller-provided source language."""
    start_time = time.time()

    source_sents = [text.strip()]
    target_prefix = [[target_lang]]

    # Same pipeline as translate_with_detection, minus the detection step
    source_sents_subworded = sp_model.encode(source_sents, out_type=str)
    source_sents_subworded = [[source_lang] + sent + ["</s>"] for sent in source_sents_subworded]

    translations = translator.translate_batch(
        source_sents_subworded,
        batch_type="tokens",
        max_batch_size=2048,
        beam_size=beam_size,
        target_prefix=target_prefix
    )

    translations = [translation[0]['tokens'] for translation in translations]
    translations_desubword = sp_model.decode(translations)
    translated_text = translations_desubword[0][len(target_lang):].strip()

    inference_time = time.time() - start_time

    return translated_text, inference_time


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "status": "ok",
        "message": "Sema Translation API is running",
        "version": "2.0.0",
        "models_loaded": all([lang_model, sp_model, translator])
    }


@app.post("/translate", response_model=TranslationResponse)
async def translate_endpoint(request: TranslationRequest):
    """
    Main translation endpoint.
    Automatically detects the source language if none is provided.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Input text cannot be empty")

    full_date, current_time = get_nairobi_time()
    print(f"\nRequest: {full_date}")
    print(f"Target: {request.target_language}, Text: {request.text[:50]}...")

    try:
        if request.source_language:
            translated_text, inference_time = translate_with_source(
                request.text,
                request.source_language,
                request.target_language
            )
            source_lang = request.source_language
        else:
            source_lang, translated_text, inference_time = translate_with_detection(
                request.text,
                request.target_language
            )

        _, response_time = get_nairobi_time()
        print(f"Response: {response_time}")
        print(f"Source: {source_lang}, Translation: {translated_text[:50]}...\n")

        return TranslationResponse(
            translated_text=translated_text,
            source_language=source_lang,
            target_language=request.target_language,
            inference_time=inference_time,
            timestamp=full_date
        )

    except Exception as e:
        print(f"Translation error: {e}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")


@app.on_event("startup")
async def startup_event():
    """Load models when the application starts."""
    print("\nStarting Sema Translation API...")
    print("Loading the Orchestra...")
    load_models()
    print("API started successfully!\n")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
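
# Example request against a locally running server (illustrative):
#   curl -X POST http://localhost:8000/translate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Habari ya asubuhi", "target_language": "eng_Latn"}'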