Zguin committed on
Commit e176f16 · verified · 1 Parent(s): 262fa8a

Update app.py

Files changed (1)
1. app.py +36 -79
app.py CHANGED
@@ -4,12 +4,8 @@ from PIL import Image
 from transformers import (
     BlipProcessor,
     BlipForConditionalGeneration,
-    Blip2Processor,
-    Blip2ForConditionalGeneration,
     M2M100Tokenizer,
-    M2M100ForConditionalGeneration,
-    AutoTokenizer,
-    AutoModelForSeq2SeqLM
+    M2M100ForConditionalGeneration
 )
 from typing import Union
 from gtts import gTTS
@@ -21,76 +17,52 @@ import gc
 torch.set_num_threads(2)
 _pipeline = None
 
-def init_pipeline(caption_model: str, translator_model: str):
+def init_pipeline():
     global _pipeline
     if _pipeline is None:
-        _pipeline = ImageCaptionPipeline(caption_model, translator_model)
+        _pipeline = ImageCaptionPipeline()
     return _pipeline
 
 class ImageCaptionPipeline:
-    def __init__(self, caption_model: str, translator_model: str):
+    def __init__(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.caption_model = caption_model
-        self.translator_model = translator_model
-
         start_time = time.time()
-        if caption_model == "BLIP":
-            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
-            self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
-        else:
-            self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-            self.blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(self.device)
-        print(f"Время загрузки {caption_model}: {time.time() - start_time:.2f} секунд")
+        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
+        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
+        print(f"Время загрузки BLIP: {time.time() - start_time:.2f} секунд")
 
         start_time = time.time()
-        if translator_model == "M2M100":
-            self.translator_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
-            self.translator_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(self.device)
-        else:
-            self.translator_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
-            self.translator_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru").to(self.device)
-        print(f"Время загрузки переводчика {translator_model}: {time.time() - start_time:.2f} секунд")
+        self.translator_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
+        self.translator_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(self.device)
+        print(f"Время загрузки переводчика: {time.time() - start_time:.2f} секунд")
 
-    def generate_english_caption(self, image: Union[str, Image.Image]) -> str:
+    def generate_captions(self, image: Union[str, Image.Image]) -> tuple:
         start_time = time.time()
         if isinstance(image, str):
             image = Image.open(image)
         image = image.convert("RGB")
-        image = image.resize((384, 384))
         inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
         with torch.no_grad():
             output_ids = self.blip_model.generate(**inputs, max_length=50, num_beams=2, early_stopping=True)
         english_caption = self.blip_processor.decode(output_ids[0], skip_special_tokens=True)
-        english_caption = english_caption[0].upper() + english_caption[1:] + ('.' if not english_caption.endswith('.') else '')
         print(f"Время генерации английской подписи: {time.time() - start_time:.2f} секунд")
 
-        gc.collect()
-        return english_caption
-
-    def translate_caption(self, english_caption: str) -> str:
         start_time = time.time()
-        if self.translator_model == "M2M100":
-            self.translator_tokenizer.src_lang = "en"
-            translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
-            with torch.no_grad():
-                translated_ids = self.translator_model.generate(
-                    **translated_inputs,
-                    forced_bos_token_id=self.translator_tokenizer.get_lang_id("ru"),
-                    max_length=50,
-                    num_beams=2,
-                    early_stopping=True
-                )
-            russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
-        else:
-            translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
-            with torch.no_grad():
-                translated_ids = self.translator_model.generate(**translated_inputs, max_length=50, num_beams=2, early_stopping=True)
-            russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
-        russian_caption = russian_caption[0].upper() + russian_caption[1:] + ('.' if not russian_caption.endswith('.') else '')
+        self.translator_tokenizer.src_lang = "en"
+        translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
+        with torch.no_grad():
+            translated_ids = self.translator_model.generate(
+                **translated_inputs,
+                forced_bos_token_id=self.translator_tokenizer.get_lang_id("ru"),
+                max_length=50,
+                num_beams=2,
+                early_stopping=True
+            )
+        russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
         print(f"Время перевода на русский: {time.time() - start_time:.2f} секунд")
 
         gc.collect()
-        return russian_caption
+        return english_caption, russian_caption
 
     def generate_audio(self, text: str, language: str) -> str:
         start_time = time.time()
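Review note: the new `generate_captions` runs captioning and translation in a single call. This also retires a latent bug in the removed code: `__init__` stored the `translator_model` string and then immediately overwrote `self.translator_model` with the loaded model object, so the `if self.translator_model == "M2M100"` check in `translate_caption` could never be true and the branch with `forced_bos_token_id` was unreachable. A minimal standalone sketch of the new BLIP-then-M2M100 flow, assuming the same checkpoints as the diff and a placeholder image path:

```python
import torch
from PIL import Image
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    M2M100Tokenizer,
    M2M100ForConditionalGeneration,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Image -> English caption with BLIP (same checkpoint as the diff).
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

image = Image.open("example.jpg").convert("RGB")  # placeholder path
inputs = blip_processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = blip_model.generate(**inputs, max_length=50, num_beams=2, early_stopping=True)
english = blip_processor.decode(output_ids[0], skip_special_tokens=True)

# English -> Russian with M2M100; forced_bos_token_id selects the target language.
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
translator = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(device)
tokenizer.src_lang = "en"
batch = tokenizer(english, return_tensors="pt").to(device)
with torch.no_grad():
    translated_ids = translator.generate(**batch, forced_bos_token_id=tokenizer.get_lang_id("ru"), max_length=50)
russian = tokenizer.decode(translated_ids[0], skip_special_tokens=True)

print(english, "->", russian)
```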
@@ -101,57 +73,42 @@ class ImageCaptionPipeline:
         print(f"Время генерации озвучки: {time.time() - start_time:.2f} секунд")
         return audio_path
 
-def generate_english_caption(image: Image.Image, caption_model: str, translator_model: str) -> tuple:
+def generate_captions(image: Image.Image) -> tuple:
     if image is not None:
-        pipeline = init_pipeline(caption_model, translator_model)
-        english_caption = pipeline.generate_english_caption(image)
-        return english_caption, "", None
-    return "Загрузите изображение.", "", None
+        pipeline = init_pipeline()
+        english_caption, russian_caption = pipeline.generate_captions(image)
+        return english_caption, russian_caption, None
+    return "Загрузите изображение.", "Загрузите изображение.", None
 
-def generate_translation(english_caption: str, caption_model: str, translator_model: str) -> str:
-    if not english_caption or english_caption == "Загрузите изображение.":
-        return ""
-    pipeline = init_pipeline(caption_model, translator_model)
-    return pipeline.translate_caption(english_caption)
-
-def generate_audio(english_caption: str, russian_caption: str, audio_language: str, caption_model: str, translator_model: str) -> str:
+def generate_audio(english_caption: str, russian_caption: str, audio_language: str) -> str:
     if not english_caption and not russian_caption:
         return None
-    pipeline = init_pipeline(caption_model, translator_model)
+    pipeline = init_pipeline()
     text = russian_caption if audio_language == "Русский" else english_caption
     return pipeline.generate_audio(text, audio_language)
 
-with gr.Blocks(css=".btn {width: 200px; background-color: #4B0082; color: white; border: none; padding: 10px 20px; text-align: center; font-size: 16px; margin: 10px auto; display: block;} .equal-height { height: 60px; }") as iface:
+with gr.Blocks(css=".btn {width: 200px; background-color: #4682B4; color: white; border: none; padding: 10px 20px; text-align: center; font-size: 16px;}") as iface:
     with gr.Row():
-        with gr.Column(scale=1, min_width=250, variant="panel"):
-            image = gr.Image(type="pil", label="Изображение", height=250, width=250)
-            caption_model = gr.Dropdown(choices=["BLIP", "BLIP-2"], label="Модель описания", value="BLIP")
+        with gr.Column(scale=1, min_width=400, variant="panel"):
+            image = gr.Image(type="pil", label="Изображение", height=400, width=400)
             submit_button = gr.Button("Сгенерировать описание", elem_classes="btn")
         with gr.Column(scale=1, min_width=300):
             english_caption = gr.Textbox(label="Подпись English:", lines=2)
             russian_caption = gr.Textbox(label="Подпись Русский:", lines=2)
-            translator_model = gr.Dropdown(choices=["M2M100", "Helsinki"], label="Модель перевода", value="M2M100")
-            translate_button = gr.Button("Сгенерировать перевод", elem_classes="btn")
             audio_button = gr.Button("Сгенерировать озвучку", elem_classes="btn")
     with gr.Row():
         audio_language = gr.Dropdown(choices=["Русский", "English"], label="Язык озвучки", value="Русский", scale=1, min_width=150, elem_classes="equal-height")
         audio_output = gr.Audio(label="Озвучка", scale=1, min_width=150, elem_classes="equal-height")
 
     submit_button.click(
-        fn=generate_english_caption,
-        inputs=[image, caption_model, translator_model],
+        fn=generate_captions,
+        inputs=[image],
         outputs=[english_caption, russian_caption, audio_output]
     )
 
-    translate_button.click(
-        fn=generate_translation,
-        inputs=[english_caption, caption_model, translator_model],
-        outputs=[russian_caption]
-    )
-
     audio_button.click(
         fn=generate_audio,
-        inputs=[english_caption, russian_caption, audio_language, caption_model, translator_model],
+        inputs=[english_caption, russian_caption, audio_language],
        outputs=[audio_output]
     )
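The unchanged body of `generate_audio` (old lines 96-101) falls between the hunks and is not shown. Judging from the `gtts` import and the returned `audio_path`, it plausibly wraps gTTS; the sketch below is an assumption for illustration only, with a hypothetical language mapping and output filename rather than the file's actual code:

```python
from gtts import gTTS

def generate_audio_sketch(text: str, language: str) -> str:
    # Hypothetical mapping; "Русский" is the dropdown value used by the UI.
    lang_code = "ru" if language == "Русский" else "en"
    tts = gTTS(text=text, lang=lang_code)
    audio_path = "caption.mp3"  # hypothetical output filename
    tts.save(audio_path)
    return audio_path
```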
 
 
114