# main.py
"""FastAPI service that extracts text from images with PaliGemma 2 (mix, 448 px).

Endpoints:
    POST /extract_text        -- one image file     -> {"extracted_text": str}
    POST /batch_extract_text  -- up to 20 image files -> {"extracted_texts": list[str]}
"""
import os
import threading
from io import BytesIO

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from huggingface_hub import login
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
from transformers.image_utils import load_image  # noqa: F401  (kept from original; unused here)

# Load environment variables (including the Hugging Face token) from a .env file.
load_dotenv()

# Set the TorchInductor cache directory to a writable path; setdefault lets a
# deployment override it through the environment.
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", "/tmp/torch_inductor_cache")

token = os.getenv("huggingface_ankit")
# huggingface_hub.login() prompts interactively when no token is given, which
# would hang or fail on a headless server, so only log in with a real token.
if token:
    login(token=token)

app = FastAPI()

model_id = "google/paligemma2-3b-mix-448"
# Fall back to CPU so the service can still start on a machine without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# PaliGemma "mix" checkpoints are steered by short task prefixes; "ocr" is the
# OCR task. Both endpoints share it so they behave identically per image.
OCR_PROMPT = "ocr"
MAX_NEW_TOKENS = 200
MAX_BATCH_SIZE = 20

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)

# Inference runs in worker threads (see the endpoints); this lock keeps GPU
# work serialized, one request at a time, as it was when it ran on the event loop.
_inference_lock = threading.Lock()


def _run_ocr(images):
    """Run the OCR prompt on a list of RGB PIL images.

    Returns:
        list[str]: one decoded string per input image, in input order.
    """
    if not images:
        return []
    prompts = [OCR_PROMPT] * len(images)
    with _inference_lock:
        # Cast floating-point inputs (pixel_values) to whatever dtype the weights
        # were loaded in; a hard-coded dtype that differs from the weights makes
        # the vision tower fail with a dtype mismatch. Integer input_ids are left as-is.
        model_inputs = (
            processor(text=prompts, images=images, return_tensors="pt")
            .to(model.dtype)
            .to(model.device)
        )
        input_len = model_inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generations = model.generate(
                **model_inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False
            )
        # generate() returns prompt + continuation; keep only the new tokens so
        # the prompt is not echoed back in the extracted text.
        texts = processor.batch_decode(
            generations[:, input_len:], skip_special_tokens=True
        )
        torch.cuda.empty_cache()  # no-op when CUDA was never initialized
    return texts


def predict(image):
    """Return the text PaliGemma reads from a single RGB PIL image."""
    return _run_ocr([image])[0]


async def _read_image(file: UploadFile) -> Image.Image:
    """Read an uploaded file and decode it as an RGB PIL image.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    data = await file.read()
    try:
        # convert() forces a full decode, so truncated files fail here too.
        return Image.open(BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail=f"{file.filename!r} is not a readable image."
        ) from err


@app.post("/extract_text")
async def extract_text(file: UploadFile = File(...)):
    """Extract text from one uploaded image."""
    image = await _read_image(file)
    # Model inference is blocking; run it off the event loop.
    text = await run_in_threadpool(predict, image)
    return {"extracted_text": text}


@app.post("/batch_extract_text")
async def batch_extract_text(files: list[UploadFile] = File(...)):
    """Extract text from up to MAX_BATCH_SIZE uploaded images in one model call."""
    if len(files) > MAX_BATCH_SIZE:
        # Same body as before, but with a 4xx status so clients can detect the rejection.
        return JSONResponse(
            status_code=400,
            content={
                "error": f"A maximum of {MAX_BATCH_SIZE} images can be processed at a time."
            },
        )
    images = [await _read_image(file) for file in files]
    extracted_texts = await run_in_threadpool(_run_ocr, images)
    return {"extracted_texts": extracted_texts}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)