# Kursovaia2025 / app.py
# Zguin's picture
# Update app.py
# b6dd7a4 verified
import gradio as gr
import torch
from PIL import Image
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
AutoTokenizer,
AutoModelForSeq2SeqLM
)
from typing import Union
from gtts import gTTS
import os
import uuid
import time
import gc
# Cap PyTorch's intra-op CPU parallelism at two threads for the whole process.
torch.set_num_threads(2)
# Process-wide ImageCaptionPipeline instance; stays None until first requested.
_pipeline = None


def init_pipeline():
    """Return the shared ImageCaptionPipeline, building it on the first call.

    Construction loads both models, so it is deferred until a request
    actually needs the pipeline; later calls reuse the cached instance.
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline
    _pipeline = ImageCaptionPipeline()
    return _pipeline
class ImageCaptionPipeline:
    """Caption an image in English, translate the caption to Russian, and voice it.

    Captioning uses the ``Salesforce/blip-image-captioning-large`` checkpoint
    and translation uses ``Helsinki-NLP/opus-mt-en-ru``. Both models are loaded
    once in ``__init__`` and moved to CUDA when it is available, otherwise
    they run on the CPU.
    """

    def __init__(self):
        """Load the BLIP captioner and the EN->RU translator, printing load times."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        start_time = time.time()
        self.blip_processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large", use_fast=True
        )
        self.blip_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to(self.device)
        print(f"Время загрузки BLIP: {time.time() - start_time:.2f} секунд")

        start_time = time.time()
        self.translator_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
        self.translator_model = AutoModelForSeq2SeqLM.from_pretrained(
            "Helsinki-NLP/opus-mt-en-ru"
        ).to(self.device)
        print(f"Время загрузки переводчика: {time.time() - start_time:.2f} секунд")

    @staticmethod
    def _prepare_image(image: Union[str, Image.Image]) -> Image.Image:
        """Return *image* as a 512x512 RGB PIL image, opening it first if it is a path."""
        if isinstance(image, str):
            # The context manager closes the underlying file deterministically;
            # convert() returns an independent in-memory copy that outlives it.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        else:
            image = image.convert("RGB")
        return image.resize((512, 512), Image.Resampling.LANCZOS)

    def _caption_image(self, image: Image.Image) -> str:
        """Run BLIP on a prepared image and return the decoded English caption."""
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output_ids = self.blip_model.generate(
                **inputs, max_length=50, num_beams=2, early_stopping=True
            )
        return self.blip_processor.decode(output_ids[0], skip_special_tokens=True)

    def _translate(self, english_caption: str) -> str:
        """Translate an English caption with the opus-mt-en-ru model."""
        translated_inputs = self.translator_tokenizer(
            english_caption, return_tensors="pt", padding=True
        ).to(self.device)
        with torch.no_grad():
            translated_ids = self.translator_model.generate(
                **translated_inputs,
                max_length=50,
                num_beams=2,
                early_stopping=True
            )
        return self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)

    def generate_captions(self, image: Union[str, Image.Image]) -> tuple:
        """Caption *image* and translate the caption.

        Args:
            image: A PIL image, or a filesystem path to one.

        Returns:
            A ``(english_caption, russian_caption)`` tuple of strings.
        """
        # The first timer deliberately covers image preparation as well as
        # BLIP generation, matching what the log line has always reported.
        start_time = time.time()
        english_caption = self._caption_image(self._prepare_image(image))
        print(f"Время генерации английской подписи: {time.time() - start_time:.2f} секунд")

        start_time = time.time()
        russian_caption = self._translate(english_caption)
        print(f"Время перевода на русский: {time.time() - start_time:.2f} секунд")

        # Explicit collection after each request to release intermediate objects.
        gc.collect()
        return english_caption, russian_caption

    def generate_audio(self, text: str, language: str) -> str:
        """Synthesize *text* with gTTS and save it as an MP3 file.

        Args:
            text: The text to speak.
            language: UI language label; ``"Русский"`` selects Russian, any
                other value selects English.

        Returns:
            The relative path of the written file, named
            ``caption_audio_<uuid4>.mp3``. The file is not removed by this
            method.
        """
        start_time = time.time()
        lang_code = "ru" if language == "Русский" else "en"
        tts = gTTS(text=text, lang=lang_code)
        # A fresh UUID per call keeps separate requests from overwriting each other.
        audio_path = f"caption_audio_{uuid.uuid4()}.mp3"
        tts.save(audio_path)
        print(f"Время генерации озвучки: {time.time() - start_time:.2f} секунд")
        return audio_path
def generate_captions(image: Image.Image) -> tuple:
    """Gradio handler: produce English and Russian captions for *image*.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A 3-tuple ``(english, russian, None)``. When *image* is ``None`` both
        text slots carry the "upload an image" prompt instead of captions.
        The third element is always ``None``.
    """
    if image is None:
        prompt = "Загрузите изображение."
        return prompt, prompt, None

    captions = init_pipeline().generate_captions(image)
    english_text, russian_text = captions
    return english_text, russian_text, None
def generate_audio(english_caption: str, russian_caption: str, audio_language: str) -> Union[str, None]:
    """Gradio handler: voice the caption that matches the selected language.

    Args:
        english_caption: Current content of the English caption box.
        russian_caption: Current content of the Russian caption box.
        audio_language: Dropdown value; ``"Русский"`` selects the Russian
            caption, any other value selects the English one.

    Returns:
        Path to the generated MP3 file, or ``None`` when the selected caption
        is missing, empty, or whitespace-only.
    """
    text = russian_caption if audio_language == "Русский" else english_caption
    # Guard on the caption that will actually be spoken: gTTS rejects empty
    # text, so checking only "both captions empty" still let an empty
    # selected caption reach the synthesizer and raise.
    if not text or not text.strip():
        return None
    pipeline = init_pipeline()
    return pipeline.generate_audio(text, audio_language)
# Gradio UI: the left panel holds the image upload and the caption button, the
# right panel shows both captions plus the text-to-speech controls.
with gr.Blocks(css="""
.btn {
width: 200px;
background-color: #4B0082;
color: white;
font-size: 16px;
}
.equal-height {
height: 100px !important;
}
""") as iface:
    with gr.Row():
        with gr.Column(scale=1, min_width=400, variant="panel"):
            with gr.Row():
                image = gr.Image(type="pil", label="Изображение", height=400, width=400)
            with gr.Row():
                submit_button = gr.Button("Сгенерировать описание", elem_classes="btn")
        with gr.Column(scale=1, variant="panel"):
            with gr.Row():
                english_caption = gr.Textbox(label="Английский язык:", lines=1, interactive=False)
                russian_caption = gr.Textbox(label="Русский язык:", lines=1, interactive=False)
            with gr.Row():
                audio_language = gr.Dropdown(
                    choices=["Русский", "English"],
                    label="Язык озвучки",
                    value="Русский",
                    elem_classes="equal-height"
                )
                audio_output = gr.Audio(
                    label="Озвучка",
                    elem_classes="equal-height"
                )
            with gr.Row():
                audio_button = gr.Button("Сгенерировать озвучку", elem_classes="btn")

    # generate_captions returns three values (english, russian, None). The
    # third one is routed to the audio player so that audio produced for a
    # previous image is cleared whenever new captions are generated; with only
    # two outputs listed that value had nowhere to go.
    submit_button.click(
        fn=generate_captions,
        inputs=[image],
        outputs=[english_caption, russian_caption, audio_output]
    )
    audio_button.click(
        fn=generate_audio,
        inputs=[english_caption, russian_caption, audio_language],
        outputs=[audio_output]
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
# Pum-pummm..