yhzx233 committed on
Commit
ccd4320
·
1 Parent(s): 3236b17
Files changed (2) hide show
  1. app.py +2 -0
  2. generation_utils.py +1 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import torch
3
  import torchaudio
@@ -149,6 +150,7 @@ def initialize_model():
149
  # Initialize model when starting the application
150
  initialize_model()
151
 
 
152
  def process_single_audio_generation(
153
  text_input: str,
154
  audio_mode: str,
 
1
+ import spaces
2
  import gradio as gr
3
  import torch
4
  import torchaudio
 
150
  # Initialize model when starting the application
151
  initialize_model()
152
 
153
+ @spaces.GPU
154
  def process_single_audio_generation(
155
  text_input: str,
156
  audio_mode: str,
generation_utils.py CHANGED
@@ -15,7 +15,7 @@ SILENCE_DURATION = 5.0 # Fixed silence duration: 5 seconds
15
  def load_model(model_path, spt_config_path, spt_checkpoint_path):
16
  tokenizer = AutoTokenizer.from_pretrained(model_path)
17
 
18
- model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
19
 
20
  spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
21
 
 
15
  def load_model(model_path, spt_config_path, spt_checkpoint_path):
16
  tokenizer = AutoTokenizer.from_pretrained(model_path)
17
 
18
+ model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa")
19
 
20
  spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
21