import base64
import os
import tempfile
from io import BytesIO
from typing import Union

import torch
import pytesseract
import whisper  # requires the `openai-whisper` package
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
# === DEVICE SETUP ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === LOAD WHISPER MODEL FOR AUDIO TRANSCRIPTION ===
try:
    whisper_model = whisper.load_model("base")  # Options: "tiny", "base", "small", "medium", "large"
except Exception as e:
    raise RuntimeError(f"❌ Failed to load Whisper model: {str(e)}")
# === LOAD BLIP FOR IMAGE CAPTIONING ===
try:
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    ).to(device)
except Exception as e:
    raise RuntimeError(f"❌ Failed to load BLIP model: {str(e)}")
# === TEXT EXTRACTION (OCR) ===
# pytesseract is only a wrapper: the Tesseract OCR engine itself must be
# installed on the system (e.g. `apt-get install tesseract-ocr`).

def extract_text_from_image_base64(image_base64: str) -> str:
    """Extract text from a base64-encoded image."""
    try:
        image_data = base64.b64decode(image_base64)
        image = Image.open(BytesIO(image_data))
        return pytesseract.image_to_string(image).strip()
    except Exception as e:
        return f"❌ OCR Error (base64): {str(e)}"

def extract_text_from_image_path(image_path: str) -> str:
    """Extract text from an image file path."""
    try:
        image = Image.open(image_path)
        return pytesseract.image_to_string(image).strip()
    except Exception as e:
        return f"❌ OCR Error (path): {str(e)}"

def extract_text_from_image_bytes(image_bytes: bytes) -> str:
    """Extract text from raw image bytes (e.g., file uploads)."""
    try:
        image = Image.open(BytesIO(image_bytes))
        return pytesseract.image_to_string(image).strip()
    except Exception as e:
        return f"❌ OCR Error (bytes): {str(e)}"

def extract_text_from_image(image_base64: str) -> str:
    """API alias for default OCR from base64 input."""
    return extract_text_from_image_base64(image_base64)
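# A minimal usage sketch for the base64 entry point (the file name is a
# hypothetical stand-in):
#     with open("scan.png", "rb") as f:
#         encoded = base64.b64encode(f.read()).decode("utf-8")
#     print(extract_text_from_image(encoded))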
# === IMAGE CAPTIONING ===
def caption_image(image: Image.Image) -> str:
    """Generate a caption from a PIL image object."""
    try:
        inputs = blip_processor(image.convert("RGB"), return_tensors="pt").to(device)
        outputs = blip_model.generate(**inputs)
        return blip_processor.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"❌ Captioning Error: {str(e)}"

def caption_image_path(image_path: str) -> str:
    """Generate a caption from an image file path."""
    try:
        image = Image.open(image_path)
        return caption_image(image)
    except Exception as e:
        return f"❌ Captioning Error (path): {str(e)}"

def caption_image_bytes(image_bytes: bytes) -> str:
    """Generate a caption from raw image bytes."""
    try:
        image = Image.open(BytesIO(image_bytes))
        return caption_image(image)
    except Exception as e:
        return f"❌ Captioning Error (bytes): {str(e)}"
def describe_image(input_data: Union[str, bytes]) -> str:
    """Unified captioning API: accepts either a file path or raw bytes."""
    try:
        if isinstance(input_data, bytes):
            return caption_image_bytes(input_data)
        elif isinstance(input_data, str):
            return caption_image_path(input_data)
        else:
            return "❌ Unsupported input type for describe_image"
    except Exception as e:
        return f"❌ Description Error: {str(e)}"
# === AUDIO TRANSCRIPTION ===
def transcribe_audio_bytes(audio_bytes: bytes) -> str:
    """Transcribe raw audio bytes using Whisper."""
    try:
        # Whisper's transcribe() expects a file path, so spill the bytes to a
        # unique temporary file (portable and safe under concurrent requests,
        # unlike a hardcoded /tmp path).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(audio_bytes)
            temp_path = f.name
        try:
            result = whisper_model.transcribe(temp_path)
            return result.get("text", "").strip()
        finally:
            os.unlink(temp_path)
    except Exception as e:
        return f"❌ Transcription Error: {str(e)}"
def transcribe_audio_path(audio_path: str) -> str:
    """Transcribe an audio file using Whisper."""
    try:
        result = whisper_model.transcribe(audio_path)
        return result.get("text", "").strip()
    except Exception as e:
        return f"❌ Transcription Error (path): {str(e)}"