ayan4m1 committed
Commit 555f172 · 1 Parent(s): 6b3b146

feat: add multiple model support

Files changed (1): app.py (+10, -2)
app.py CHANGED
@@ -1,18 +1,25 @@
 import gradio as gr
 from transformers import pipeline
 
-pipe = pipeline("text-generation", model="pszemraj/distilgpt2-magicprompt-SD")
+models = {
+    "MagicPrompt": "pszemraj/distilgpt2-magicprompt-SD",
+    "Llama-SmolTalk-3.2-1B": "prithivMLmods/Llama-SmolTalk-3.2-1B-Instruct"
+}
+pipelines = {}
 
+for key, value in models.items():
+    pipelines[key] = pipeline("text-generation", model=value)
 
 def respond(
     message,
     _: list[tuple[str, str]],
+    model: str,
     max_new_tokens: int,
     temperature: float,
     top_p: float,
     top_k: int
 ):
-    yield pipe(
+    yield pipelines[model](
         message,
         max_new_tokens=max_new_tokens,
         do_sample=True,
@@ -28,6 +35,7 @@ For information on how to customize the ChatInterface, peruse the gradio docs: h
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
+        gr.Radio(choices=list(models.keys()), value="MagicPrompt", type="value", label="Model"),
         gr.Slider(minimum=8, maximum=128, value=64, step=8, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(
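
With the new app.py in scope, the Radio passes its selected key straight through to respond, which looks up the matching pipeline. A minimal usage sketch; the prompt and the top_p/top_k values are illustrative, not part of the commit:

    # Direct call to the respond generator defined in app.py above.
    # "MagicPrompt" is one of the keys in the models dict; respond uses it
    # to select pipelines["MagicPrompt"] for generation.
    for output in respond(
        "a castle in the clouds",  # message (illustrative prompt)
        [],                        # chat history, ignored by respond
        "MagicPrompt",             # model key supplied by the Radio
        64,                        # max_new_tokens (slider default)
        0.7,                       # temperature (slider default)
        0.95,                      # top_p (assumed value)
        50,                        # top_k (assumed value)
    ):
        print(output)  # raw pipeline output, e.g. [{"generated_text": ...}]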