filipsedivy committed on
Commit
f0a5e1c
·
1 Parent(s): 379c6ae

Update inference

Browse files
Files changed (1) hide show
  1. app.py +12 -24
app.py CHANGED
@@ -1,23 +1,17 @@
1
  import gradio as gr
2
-
3
  from langchain.embeddings import SentenceTransformerEmbeddings
4
  from langchain.vectorstores import Chroma
5
-
6
- from transformers import T5Tokenizer, T5ForConditionalGeneration
7
 
8
  embeddings = SentenceTransformerEmbeddings(model_name="msmarco-distilbert-base-v4")
9
  db = Chroma(persist_directory="embeddings", embedding_function=embeddings)
10
 
11
- tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
12
- model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
13
 
14
 
15
  def respond(
16
  message,
17
  history: list[tuple[str, str]],
18
- max_tokens,
19
- temperature,
20
- repetition_penalty,
21
  ):
22
  matching_docs = db.similarity_search(message)
23
 
@@ -37,28 +31,22 @@ def respond(
37
  f"Answer:"
38
  )
39
 
40
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
41
-
42
- outputs = model.generate(input_ids,
43
- do_sample=True,
44
- max_new_tokens=max_tokens,
45
- temperature=temperature,
46
- repetition_penalty=repetition_penalty)
47
 
48
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
49
 
50
 
51
  demo = gr.ChatInterface(
52
  respond,
53
- additional_inputs=[
54
- gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
55
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
56
- gr.Slider(minimum=0.1, maximum=10, value=1.5, step=0.1, label="Repetition penalty"),
57
- ],
58
  examples=[
59
- "What types of roles are in the system?",
60
- "How to import records into stock receipts in Boost.space?",
61
- "Is it possible to create a PDF export from the product?",
62
  ],
63
  )
64
 
 
1
  import gradio as gr
 
2
  from langchain.embeddings import SentenceTransformerEmbeddings
3
  from langchain.vectorstores import Chroma
4
+ from huggingface_hub import InferenceClient
 
5
 
6
# Sentence-embedding model used to vectorize user queries for similarity search.
embeddings = SentenceTransformerEmbeddings(model_name="msmarco-distilbert-base-v4")

# Chroma vector store persisted on disk in the "embeddings" directory,
# queried with the same embedding function the index was built with.
db = Chroma(persist_directory="embeddings", embedding_function=embeddings)

# Hosted Hugging Face inference endpoint for the FLAN-T5 generator
# (replaces the previous local T5 tokenizer/model pair).
client = InferenceClient(model="google/flan-t5-large")
 
10
 
11
 
12
  def respond(
13
  message,
14
  history: list[tuple[str, str]],
 
 
 
15
  ):
16
  matching_docs = db.similarity_search(message)
17
 
 
31
  f"Answer:"
32
  )
33
 
34
+ response = client.text_generation(
35
+ prompt,
36
+ max_new_tokens=250,
37
+ temperature=0.7,
38
+ top_p=0.95,
39
+ )
 
40
 
41
+ yield response
42
 
43
 
44
# Chat UI: wires the retrieval-augmented `respond` generator into a
# Gradio chat interface, seeded with a few example questions.
demo = gr.ChatInterface(
    respond,
    examples=[
        ["What types of roles are in the system?"],
        ["How to import records into stock receipts in Boost.space?"],
        ["Is it possible to create a PDF export from the product?"],
    ],
)
52