Clone04 commited on
Commit
b731001
·
verified ·
1 Parent(s): e06a3af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -1,16 +1,28 @@
1
  import gradio as gr
2
  from optimum.intel import OVDiffusionPipeline
 
3
  from threading import Lock
 
4
 
5
- # Load the pipeline globally (includes the correct tokenizer)
 
 
 
6
  model_id = "OpenVINO/FLUX.1-schnell-int4-ov"
7
  pipeline = OVDiffusionPipeline.from_pretrained(model_id, device="CPU")
 
 
 
 
 
 
8
  lock = Lock()
9
 
10
  # Define the image generation function
11
  def generate_image(prompt):
12
  with lock:
13
- image = pipeline(prompt, num_inference_steps=4, guidance_scale=3.5).images[0]
 
14
  return image
15
 
16
  # Create the Gradio interface
@@ -26,4 +38,4 @@ interface = gr.Interface(
26
 
27
  # Launch the interface
28
  if __name__ == "__main__":
29
- interface.launch()
 
1
  import gradio as gr
2
  from optimum.intel import OVDiffusionPipeline
3
+ from transformers import AutoTokenizer
4
  from threading import Lock
5
+ import warnings
6
 
7
+ # Suppress deprecation warnings (optional)
8
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
9
+
10
+ # Load the pipeline globally
11
  model_id = "OpenVINO/FLUX.1-schnell-int4-ov"
12
  pipeline = OVDiffusionPipeline.from_pretrained(model_id, device="CPU")
13
+
14
+ # Explicitly load and configure the tokenizer
15
+ # FLUX.1-schnell uses a T5-based encoder, typically t5-v1_1-xl or similar
16
+ tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl", use_fast=True, add_prefix_space=True)
17
+ pipeline.text_encoder_2.tokenizer = tokenizer # Assign to the T5 encoder
18
+
19
  lock = Lock()
20
 
21
  # Define the image generation function
22
  def generate_image(prompt):
23
  with lock:
24
+ # Reduce num_inference_steps for faster inference to avoid timeouts
25
+ image = pipeline(prompt, num_inference_steps=2, guidance_scale=3.5).images[0]
26
  return image
27
 
28
  # Create the Gradio interface
 
38
 
39
  # Launch the interface
40
  if __name__ == "__main__":
41
+ interface.launch(server_name="0.0.0.0", server_port=7860) # Explicitly set for Spaces