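"""Gradio demo that describes an uploaded image.

BLIP (Salesforce/blip-image-captioning-large) produces an English caption,
Helsinki-NLP/opus-mt-en-ru translates it to Russian, and gTTS can voice the
caption in either language.
"""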
import gradio as gr
import torch
from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
from typing import Union
from gtts import gTTS
import uuid
import time
import gc

torch.set_num_threads(2)
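# Shared pipeline instance; created lazily by init_pipeline() on the first request.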
_pipeline = None

def init_pipeline():
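    """Create the ImageCaptionPipeline once and reuse it for all requests."""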
    global _pipeline
    if _pipeline is None:
        _pipeline = ImageCaptionPipeline()
    return _pipeline

class ImageCaptionPipeline:
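    """Image captioning (BLIP), English-to-Russian translation, and gTTS speech synthesis."""
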
    def __init__(self):
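        """Load the BLIP and translation models, using the GPU when available."""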
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        start_time = time.time()
        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
        print(f"Время загрузки BLIP: {time.time() - start_time:.2f} секунд")
        
        start_time = time.time()
        self.translator_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
        self.translator_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru").to(self.device)
        print(f"Время загрузки переводчика: {time.time() - start_time:.2f} секунд")

    def generate_captions(self, image: Union[str, Image.Image]) -> tuple:
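        """Return an (English, Russian) caption pair for a file path or PIL image."""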
        start_time = time.time()
        if isinstance(image, str):
            image = Image.open(image)
        image = image.convert("RGB")
        image = image.resize((512, 512), Image.Resampling.LANCZOS)
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output_ids = self.blip_model.generate(**inputs, max_length=50, num_beams=2, early_stopping=True)
            english_caption = self.blip_processor.decode(output_ids[0], skip_special_tokens=True)
        print(f"Время генерации английской подписи: {time.time() - start_time:.2f} секунд")
        
        start_time = time.time()
        translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
        with torch.no_grad():
            translated_ids = self.translator_model.generate(
                **translated_inputs,
                max_length=50,
                num_beams=2,
                early_stopping=True
            )
            russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        print(f"Время перевода на русский: {time.time() - start_time:.2f} секунд")
        
        gc.collect()
        return english_caption, russian_caption

    def generate_audio(self, text: str, language: str) -> str:
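        """Save the text as speech with gTTS and return the path to a uniquely named MP3 file."""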
        start_time = time.time()
        lang_code = "ru" if language == "Русский" else "en"
        tts = gTTS(text=text, lang=lang_code)
        audio_path = f"caption_audio_{uuid.uuid4()}.mp3"
        tts.save(audio_path)
        print(f"Время генерации озвучки: {time.time() - start_time:.2f} секунд")
        return audio_path

def generate_captions(image: Image.Image) -> tuple:
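    """Gradio callback: caption the uploaded image; the third output resets the audio player."""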
    if image is not None:
        pipeline = init_pipeline()
        english_caption, russian_caption = pipeline.generate_captions(image)
        return english_caption, russian_caption, None
    return "Загрузите изображение.", "Загрузите изображение.", None

def generate_audio(english_caption: str, russian_caption: str, audio_language: str) -> str:
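    """Gradio callback: voice the English or Russian caption, depending on the selected language."""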
    if not english_caption and not russian_caption:
        return None
    pipeline = init_pipeline()
    text = russian_caption if audio_language == "Русский" else english_caption
    return pipeline.generate_audio(text, audio_language)

with gr.Blocks(css=".btn {width: 200px; background-color: #4B0082; color: white; border: none; padding: 10px 20px; text-align: center; font-size: 16px; margin: 0 auto; display: block;} .equal-height {height: 60px !important;") as iface:
    with gr.Row():
        with gr.Column(scale=1, min_width=400, variant="panel"):
            image = gr.Image(type="pil", label="Изображение", height=400, width=400)
            submit_button = gr.Button("Сгенерировать описание", elem_classes="btn")
        with gr.Column(scale=1, min_width=300):
            english_caption = gr.Textbox(label="Описание на English:", lines=2)
            russian_caption = gr.Textbox(label="Описание на Русском:", lines=2)
        with gr.Row():
            audio_language = gr.Dropdown(choices=["Русский", "English"], label="Язык озвучки", value="Русский", scale=1, min_width=200, elem_classes="equal-height")
            audio_output = gr.Audio(label="Озвучка", scale=1, min_width=200, elem_classes="equal-height")
        with gr.Column(scale=1):
            audio_button = gr.Button("Сгенерировать озвучку", elem_classes="btn")
    
    submit_button.click(
        fn=generate_captions,
        inputs=[image],
        outputs=[english_caption, russian_caption, audio_output]
    )
    
    audio_button.click(
        fn=generate_audio,
        inputs=[english_caption, russian_caption, audio_language],
        outputs=[audio_output]
    )

if __name__ == "__main__":
    iface.launch()