Ava Pun commited on
Commit
cc9cd8f
·
1 Parent(s): 5d131e0
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -43,7 +43,7 @@ def main():
43
  setup()
44
 
45
  model_cfg = BrickGPTConfig(max_regenerations=5, device='cuda')
46
- generator = BrickGenerator(BrickGPT(model_cfg))
47
 
48
  # Define inputs and outputs
49
  in_prompt = gr.Textbox(label='Prompt', placeholder='Enter a prompt to generate a brick model.', max_length=500)
@@ -100,9 +100,9 @@ def main():
100
 
101
 
102
  class BrickGenerator:
103
- def __init__(self, model: BrickGPT):
104
- self.model = model
105
- self.ctx = mp.get_context('spawn')
106
 
107
  @spaces.GPU
108
  def generate_bricks(
@@ -115,6 +115,8 @@ class BrickGenerator:
115
  max_brick_rejections: int | None,
116
  max_regenerations: int | None,
117
  ) -> (str, str):
 
 
118
  # Set model parameters
119
  if temperature is not None: self.model.temperature = temperature
120
  if max_bricks is not None: self.model.max_bricks = max_bricks
@@ -173,13 +175,6 @@ class BrickGenerator:
173
 
174
  return img_filename, output['bricks'].to_txt()
175
 
176
- def generate_bricks_subprocess(self, *args) -> (str, str):
177
- """
178
- Run generation as a subprocess so that multiple requests can be handled concurrently.
179
- """
180
- with self.ctx.Pool(1) as pool:
181
- return pool.starmap(self.generate_bricks, [args])[0]
182
-
183
 
184
  def get_help_string(field_name: str) -> str:
185
  """
 
43
  setup()
44
 
45
  model_cfg = BrickGPTConfig(max_regenerations=5, device='cuda')
46
+ generator = BrickGenerator(model_cfg)
47
 
48
  # Define inputs and outputs
49
  in_prompt = gr.Textbox(label='Prompt', placeholder='Enter a prompt to generate a brick model.', max_length=500)
 
100
 
101
 
102
  class BrickGenerator:
103
+ def __init__(self, model_cfg: BrickGPTConfig):
104
+ self.model_cfg = model_cfg
105
+ self.model = None
106
 
107
  @spaces.GPU
108
  def generate_bricks(
 
115
  max_brick_rejections: int | None,
116
  max_regenerations: int | None,
117
  ) -> (str, str):
118
+ self.model = BrickGPT(self.model_cfg)
119
+
120
  # Set model parameters
121
  if temperature is not None: self.model.temperature = temperature
122
  if max_bricks is not None: self.model.max_bricks = max_bricks
 
175
 
176
  return img_filename, output['bricks'].to_txt()
177
 
 
 
 
 
 
 
 
178
 
179
  def get_help_string(field_name: str) -> str:
180
  """