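# Gradio demo: BLIP generates an English caption for an uploaded image,
# Helsinki-NLP/opus-mt-en-ru translates it into Russian, and gTTS voices
# either caption on demand.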
import gradio as gr
import torch
from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
from typing import Union
from gtts import gTTS
import os
import uuid
import time
import gc
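# Assumed runtime dependencies (not pinned anywhere in this file): gradio, torch,
# transformers, sentencepiece (likely needed by the Marian en-ru tokenizer),
# Pillow, gTTS.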
# Cap CPU threads; useful on small CPU-only hosts such as free Spaces hardware.
torch.set_num_threads(2)

# Build the pipeline lazily, on first use, so the app process starts quickly.
_pipeline = None

def init_pipeline():
    global _pipeline
    if _pipeline is None:
        _pipeline = ImageCaptionPipeline()
    return _pipeline
class ImageCaptionPipeline:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # BLIP produces the English caption.
        start_time = time.time()
        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
        print(f"BLIP load time: {time.time() - start_time:.2f} seconds")
        # The Marian en-ru model translates that caption into Russian.
        start_time = time.time()
        self.translator_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
        self.translator_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru").to(self.device)
        print(f"Translator load time: {time.time() - start_time:.2f} seconds")
    def generate_captions(self, image: Union[str, Image.Image]) -> tuple:
        start_time = time.time()
        if isinstance(image, str):
            image = Image.open(image)
        image = image.convert("RGB")
        # Pre-shrink large uploads (squashes to a square; the BLIP processor
        # resizes again during preprocessing anyway).
        image = image.resize((512, 512), Image.Resampling.LANCZOS)
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output_ids = self.blip_model.generate(**inputs, max_length=50, num_beams=2, early_stopping=True)
        english_caption = self.blip_processor.decode(output_ids[0], skip_special_tokens=True)
        print(f"English caption generation time: {time.time() - start_time:.2f} seconds")
        # Translate the English caption into Russian.
        start_time = time.time()
        translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
        with torch.no_grad():
            translated_ids = self.translator_model.generate(
                **translated_inputs,
                max_length=50,
                num_beams=2,
                early_stopping=True
            )
        russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        print(f"Russian translation time: {time.time() - start_time:.2f} seconds")
        gc.collect()
        return english_caption, russian_caption
    def generate_audio(self, text: str, language: str) -> str:
        start_time = time.time()
        lang_code = "ru" if language == "Russian" else "en"
        tts = gTTS(text=text, lang=lang_code)
        # Unique file name so concurrent requests do not overwrite each other.
        audio_path = f"caption_audio_{uuid.uuid4()}.mp3"
        tts.save(audio_path)
        print(f"Speech synthesis time: {time.time() - start_time:.2f} seconds")
        return audio_path
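
# Note: every call to generate_audio leaves a caption_audio_<uuid>.mp3 in the
# working directory, and nothing deletes them. A minimal cleanup sketch follows;
# the helper is hypothetical and is not wired into the UI, assuming one would
# want to bound disk usage on a long-running Space:
def cleanup_old_audio(max_files: int = 20) -> None:
    """Delete all but the most recent generated mp3 files."""
    import glob
    # Oldest first, so slicing off the tail keeps the newest max_files.
    files = sorted(glob.glob("caption_audio_*.mp3"), key=os.path.getmtime)
    for path in files[:-max_files]:
        try:
            os.remove(path)
        except OSError:
            pass  # File may already be gone; ignore.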
def generate_captions(image: Image.Image) -> tuple:
    # The third value resets the audio player whenever a new image is captioned.
    if image is not None:
        pipeline = init_pipeline()
        english_caption, russian_caption = pipeline.generate_captions(image)
        return english_caption, russian_caption, None
    return "Upload an image.", "Upload an image.", None
def generate_audio(english_caption: str, russian_caption: str, audio_language: str) -> Union[str, None]:
    if not english_caption and not russian_caption:
        return None
    pipeline = init_pipeline()
    text = russian_caption if audio_language == "Russian" else english_caption
    return pipeline.generate_audio(text, audio_language)
with gr.Blocks(css="""
    .btn {
        width: 200px;
        background-color: #4B0082;
        color: white;
        font-size: 16px;
    }
    .equal-height {
        height: 100px !important;
    }
""") as iface:
    with gr.Row():
        with gr.Column(scale=1, min_width=400, variant="panel"):
            with gr.Row():
                image = gr.Image(type="pil", label="Image", height=400, width=400)
            with gr.Row():
                submit_button = gr.Button("Generate caption", elem_classes="btn")
        with gr.Column(scale=1, variant="panel"):
            with gr.Row():
                english_caption = gr.Textbox(label="English:", lines=1, interactive=False)
                russian_caption = gr.Textbox(label="Russian:", lines=1, interactive=False)
            with gr.Row():
                audio_language = gr.Dropdown(
                    choices=["Russian", "English"],
                    label="Audio language",
                    value="Russian",
                    elem_classes="equal-height"
                )
                audio_output = gr.Audio(
                    label="Audio",
                    elem_classes="equal-height"
                )
            with gr.Row():
                audio_button = gr.Button("Generate audio", elem_classes="btn")
    submit_button.click(
        fn=generate_captions,
        inputs=[image],
        # audio_output must be listed here: the handler returns three values,
        # with None for the audio slot to clear any stale clip.
        outputs=[english_caption, russian_caption, audio_output]
    )
    audio_button.click(
        fn=generate_audio,
        inputs=[english_caption, russian_caption, audio_language],
        outputs=[audio_output]
    )
if __name__ == "__main__":
    iface.launch()
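# To run locally: `python app.py`, then open the printed URL (Gradio defaults to
# http://127.0.0.1:7860). Assumes the file is saved as app.py, per the Spaces convention.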
# Pum-pummm..