cmcmaster commited on
Commit
c0ba980
·
verified ·
1 Parent(s): 943894e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -44
app.py CHANGED
@@ -1,69 +1,49 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
3
  from pydantic import BaseModel, Field
4
  from typing import Optional
 
5
 
6
- # Define the schema
7
  class Medication(BaseModel):
8
  drug_name: str = Field(description="The name of the drug.")
9
- is_generic: bool = Field(
10
- description="Indicates if the drug name is a generic drug name (e.g. 'Tylenol' is not generic, 'paracetamol' or 'acetaminophen' is generic)."
11
- )
12
  strength: Optional[str] = Field(default=None, description="The strength of the drug.")
13
  unit: Optional[str] = Field(default=None, description="The unit of measurement for the drug strength.")
14
  dosage_form: Optional[str] = Field(default=None, description="The form of the drug (e.g., patch, tablet).")
15
  frequency: Optional[str] = Field(default=None, description="The frequency of drug administration.")
16
  route: Optional[str] = Field(default=None, description="The route of administration (e.g., oral, topical).")
17
- is_prn: Optional[bool] = Field(default=None, description="Whether the medication is taken 'as needed' (pro re nata).")
18
  total_daily_dose_mg: Optional[float] = Field(default=None, description="The total daily dose in milligrams.")
19
 
20
- # Get the schema for structured generation
21
- schema = Medication.schema()
 
 
 
 
 
 
22
 
23
- # Connect to your model
24
- client = InferenceClient("cmcmaster/drug_parsing_Llama-3.2-1B-Instruct")
25
 
26
- # Response function
27
  def respond(
28
  message,
29
  history: list[tuple[str, str]],
30
- system_message,
31
- max_tokens,
32
- temperature,
33
- top_p,
34
  ):
35
- messages = [{"role": "system", "content": system_message}]
36
-
37
- for user_msg, assistant_msg in history:
38
- if user_msg:
39
- messages.append({"role": "user", "content": user_msg})
40
- if assistant_msg:
41
- messages.append({"role": "assistant", "content": assistant_msg})
42
-
43
- messages.append({"role": "user", "content": message})
44
-
45
- # Structured generation with schema
46
- output = client.chat_completion(
47
- messages=messages,
48
- max_tokens=max_tokens,
49
- temperature=temperature,
50
- top_p=top_p,
51
- stream=False,
52
- response_format={"type": "json", "value": schema},
53
- )
54
 
55
- content = output.choices[0].message.content
56
- yield content
57
 
58
- # Gradio app
59
  demo = gr.ChatInterface(
60
- respond,
61
- additional_inputs=[
62
- gr.Textbox(value="Extract structured medication details from this input.", label="System message"),
63
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
64
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
65
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
66
- ],
67
  )
68
 
69
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from outlines.models.llamacpp import LlamaCpp
3
+ from outlines import generate, samplers
4
  from pydantic import BaseModel, Field
5
  from typing import Optional
6
+ import json
7
 
8
+ # Define the output schema
9
  class Medication(BaseModel):
10
  drug_name: str = Field(description="The name of the drug.")
11
+ is_generic: bool = Field(description="Indicates if the drug name is a generic drug name.")
 
 
12
  strength: Optional[str] = Field(default=None, description="The strength of the drug.")
13
  unit: Optional[str] = Field(default=None, description="The unit of measurement for the drug strength.")
14
  dosage_form: Optional[str] = Field(default=None, description="The form of the drug (e.g., patch, tablet).")
15
  frequency: Optional[str] = Field(default=None, description="The frequency of drug administration.")
16
  route: Optional[str] = Field(default=None, description="The route of administration (e.g., oral, topical).")
17
+ is_prn: Optional[bool] = Field(default=None, description="Whether the medication is taken 'as needed'.")
18
  total_daily_dose_mg: Optional[float] = Field(default=None, description="The total daily dose in milligrams.")
19
 
20
+ # Load your model locally via llama-cpp
21
+ model = LlamaCpp(
22
+ model_path="/path/to/cmcmaster/drug_parsing_Llama-3.2-1B-Instruct-Q5_K_S-GGUF.gguf", # Change this path
23
+ temperature=0.0,
24
+ max_tokens=512
25
+ )
26
+
27
+ sampler = samplers.greedy()
28
 
29
+ # Prepare structured generator
30
+ structured_generator = generate.json(model, Medication, sampler = sampler)
31
 
 
32
  def respond(
33
  message,
34
  history: list[tuple[str, str]],
 
 
 
 
35
  ):
36
+ try:
37
+ medication = structured_generator(message)
38
+ response = json.dumps(medication.model_dump(), indent=2)
39
+ except Exception as e:
40
+ response = f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ yield response
 
43
 
44
+ # Gradio interface
45
  demo = gr.ChatInterface(
46
+ respond
 
 
 
 
 
 
47
  )
48
 
49
  if __name__ == "__main__":