Update app.py
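Simplifies the Space to a single fixed model pair. The BLIP-2 (Salesforce/blip2-opt-2.7b) and Helsinki-NLP (opus-mt-en-ru) options are removed together with their model-selection dropdowns; captioning now always uses Salesforce/blip-image-captioning-large and translation always uses facebook/m2m100_418M. Caption generation and English-to-Russian translation are merged into a single generate_captions step that returns both captions at once, so the separate "Сгенерировать перевод" ("Generate translation") button and its handler are gone, and init_pipeline / ImageCaptionPipeline no longer take model arguments.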
app.py CHANGED
@@ -4,12 +4,8 @@ from PIL import Image
 from transformers import (
     BlipProcessor,
     BlipForConditionalGeneration,
-    Blip2Processor,
-    Blip2ForConditionalGeneration,
     M2M100Tokenizer,
-    M2M100ForConditionalGeneration,
-    AutoTokenizer,
-    AutoModelForSeq2SeqLM
+    M2M100ForConditionalGeneration
 )
 from typing import Union
 from gtts import gTTS
@@ -21,76 +17,52 @@ import gc
 torch.set_num_threads(2)
 _pipeline = None
 
-def init_pipeline(caption_model: str, translator_model: str):
+def init_pipeline():
     global _pipeline
     if _pipeline is None:
-        _pipeline = ImageCaptionPipeline(caption_model, translator_model)
+        _pipeline = ImageCaptionPipeline()
     return _pipeline
 
 class ImageCaptionPipeline:
-    def __init__(self, caption_model: str, translator_model: str):
+    def __init__(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.caption_model = caption_model
-        self.translator_model = translator_model
-
         start_time = time.time()
-        if caption_model == "BLIP":
-            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
-            self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
-        else:
-            self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-            self.blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(self.device)
-        print(f"Время загрузки {caption_model}: {time.time() - start_time:.2f} секунд")
+        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
+        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
+        print(f"Время загрузки BLIP: {time.time() - start_time:.2f} секунд")
 
         start_time = time.time()
-        if translator_model == "M2M100":
-            self.translator_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
-            self.translator_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(self.device)
-        else:
-            self.translator_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
-            self.translator_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru").to(self.device)
-        print(f"Время загрузки переводчика {translator_model}: {time.time() - start_time:.2f} секунд")
+        self.translator_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
+        self.translator_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(self.device)
+        print(f"Время загрузки переводчика: {time.time() - start_time:.2f} секунд")
 
-    def …
+    def generate_captions(self, image: Union[str, Image.Image]) -> tuple:
         start_time = time.time()
         if isinstance(image, str):
             image = Image.open(image)
         image = image.convert("RGB")
-        image = image.resize((384, 384))
         inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
         with torch.no_grad():
             output_ids = self.blip_model.generate(**inputs, max_length=50, num_beams=2, early_stopping=True)
         english_caption = self.blip_processor.decode(output_ids[0], skip_special_tokens=True)
-        english_caption = english_caption[0].upper() + english_caption[1:] + ('.' if not english_caption.endswith('.') else '')
         print(f"Время генерации английской подписи: {time.time() - start_time:.2f} секунд")
 
-        gc.collect()
-        return english_caption
-
-    def translate_caption(self, english_caption: str) -> str:
         start_time = time.time()
-        if …:
-            self.translator_tokenizer.src_lang = "en"
-            translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
-            with torch.no_grad():
-                translated_ids = self.translator_model.generate(
-                    **translated_inputs,
-                    forced_bos_token_id=self.translator_tokenizer.get_lang_id("ru"),
-                    max_length=50,
-                    num_beams=2,
-                    early_stopping=True
-                )
-            russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
-        else:
-            translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
-            with torch.no_grad():
-                translated_ids = self.translator_model.generate(**translated_inputs, max_length=50, num_beams=2, early_stopping=True)
-            russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
-        russian_caption = russian_caption[0].upper() + russian_caption[1:] + ('.' if not russian_caption.endswith('.') else '')
+        self.translator_tokenizer.src_lang = "en"
+        translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
+        with torch.no_grad():
+            translated_ids = self.translator_model.generate(
+                **translated_inputs,
+                forced_bos_token_id=self.translator_tokenizer.get_lang_id("ru"),
+                max_length=50,
+                num_beams=2,
+                early_stopping=True
+            )
+        russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
         print(f"Время перевода на русский: {time.time() - start_time:.2f} секунд")
 
         gc.collect()
-        return russian_caption
+        return english_caption, russian_caption
 
     def generate_audio(self, text: str, language: str) -> str:
         start_time = time.time()
@@ -101,57 +73,42 @@ class ImageCaptionPipeline:
         print(f"Время генерации озвучки: {time.time() - start_time:.2f} секунд")
         return audio_path
 
-def …
+def generate_captions(image: Image.Image) -> tuple:
     if image is not None:
-        pipeline = init_pipeline(caption_model, translator_model)
-        english_caption = pipeline.…
-        return english_caption, …
-    return "Загрузите изображение.", "", None
+        pipeline = init_pipeline()
+        english_caption, russian_caption = pipeline.generate_captions(image)
+        return english_caption, russian_caption, None
+    return "Загрузите изображение.", "Загрузите изображение.", None
 
-def generate_translation(english_caption: str, caption_model: str, translator_model: str) -> str:
-    if not english_caption or english_caption == "Загрузите изображение.":
-        return ""
-    pipeline = init_pipeline(caption_model, translator_model)
-    return pipeline.translate_caption(english_caption)
-
-def generate_audio(english_caption: str, russian_caption: str, audio_language: str, caption_model: str, translator_model: str) -> str:
+def generate_audio(english_caption: str, russian_caption: str, audio_language: str) -> str:
     if not english_caption and not russian_caption:
         return None
-    pipeline = init_pipeline(caption_model, translator_model)
+    pipeline = init_pipeline()
    text = russian_caption if audio_language == "Русский" else english_caption
     return pipeline.generate_audio(text, audio_language)
 
-with gr.Blocks(css=".btn {width: 200px; background-color: #…") as iface:
+with gr.Blocks(css=".btn {width: 200px; background-color: #4682B4; color: white; border: none; padding: 10px 20px; text-align: center; font-size: 16px;}") as iface:
     with gr.Row():
-        with gr.Column(scale=1, min_width=…):
-            image = gr.Image(type="pil", label="Изображение", height=…)
-            caption_model = gr.Dropdown(choices=["BLIP", "BLIP-2"], label="Модель описания", value="BLIP")
+        with gr.Column(scale=1, min_width=400, variant="panel"):
+            image = gr.Image(type="pil", label="Изображение", height=400, width=400)
             submit_button = gr.Button("Сгенерировать описание", elem_classes="btn")
         with gr.Column(scale=1, min_width=300):
             english_caption = gr.Textbox(label="Подпись English:", lines=2)
             russian_caption = gr.Textbox(label="Подпись Русский:", lines=2)
-            translator_model = gr.Dropdown(choices=["M2M100", "Helsinki"], label="Модель перевода", value="M2M100")
-            translate_button = gr.Button("Сгенерировать перевод", elem_classes="btn")
             audio_button = gr.Button("Сгенерировать озвучку", elem_classes="btn")
     with gr.Row():
         audio_language = gr.Dropdown(choices=["Русский", "English"], label="Язык озвучки", value="Русский", scale=1, min_width=150, elem_classes="equal-height")
         audio_output = gr.Audio(label="Озвучка", scale=1, min_width=150, elem_classes="equal-height")
 
     submit_button.click(
-        fn=…,
-        inputs=[image, …],
+        fn=generate_captions,
+        inputs=[image],
         outputs=[english_caption, russian_caption, audio_output]
     )
 
-    translate_button.click(
-        fn=generate_translation,
-        inputs=[english_caption, caption_model, translator_model],
-        outputs=[russian_caption]
-    )
-
     audio_button.click(
         fn=generate_audio,
-        inputs=[english_caption, russian_caption, audio_language, caption_model, translator_model],
+        inputs=[english_caption, russian_caption, audio_language],
         outputs=[audio_output]
     )
 
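The least obvious part of the new code is the M2M100 generation call: unlike the removed Helsinki-NLP model, M2M100 is a many-to-many multilingual model, so the tokenizer has to be told the source language and the decoder has to be forced to start with the target-language token. A minimal self-contained sketch of that pattern (the input sentence is an arbitrary stand-in for a BLIP caption):

import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")

tokenizer.src_lang = "en"  # declare the input language before tokenizing
inputs = tokenizer("a cat sitting on a windowsill", return_tensors="pt")
with torch.no_grad():
    ids = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.get_lang_id("ru"),  # force Russian output
        max_length=50,
        num_beams=2,
        early_stopping=True,
    )
print(tokenizer.decode(ids[0], skip_special_tokens=True))

Without forced_bos_token_id the model is free to decode into any of its languages, which is why the old Helsinki branch (a dedicated en-ru model) did not need it but the M2M100 path does.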