fCola committed
Commit d9e2d70 · verified · 1 Parent(s): 461067a

Update app.py

Files changed (1)
  1. app.py +56 -20
app.py CHANGED
@@ -7,6 +7,58 @@ from gradio.themes.utils import colors
 
 from transformers import pipeline, TextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
 
+SYSTEM_PROMPT = "You are a compliance assistant. Use the provided risk data to answer user questions. If a single risk object is given, provide a direct answer. If a list of risks is provided, summarize, compare, or analyze the collection as needed. Always base your response on the data provided."
+
+class HfModelWrapper:
+    def __init__(
+        self,
+        model_path="casperhansen/llama-3.3-70b-instruct-awq",
+        sys_prompt=SYSTEM_PROMPT,
+        adapter_path="artemisiaai/fine-tuned-adapter",
+    ):
+        # The AWQ checkpoint carries its own quantization config, so
+        # from_pretrained only needs a device map here.
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path, device_map="auto"
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.sys_prompt = sys_prompt
+        self.adapter_path = adapter_path
+        self.model.load_adapter(self.adapter_path)
+        self.model.enable_adapters()
+
+    def build_prompt(self, user_msg, history):
+        # Prepend the system prompt, replay the chat history, then
+        # append the new user turn.
+        messages = [{"role": "system", "content": self.sys_prompt}]
+        messages += history
+        messages.append({"role": "user", "content": user_msg})
+
+        prompt = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+        return prompt
+
+    def generate(self, user_input, history):
+        input_text = self.build_prompt(user_input, history)
+        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to("cuda")
+
+        # Skip the prompt echo and special tokens in the streamed text.
+        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+        gen_kwargs = {
+            "inputs": input_ids,
+            "streamer": streamer,
+            "pad_token_id": self.tokenizer.eos_token_id,
+            "max_length": 8192,
+            "temperature": 0.1,
+            "top_p": 0.8,
+            "repetition_penalty": 1.25,
+        }
+
+        # Run generation on a background thread; the streamer yields
+        # decoded text chunks as they are produced.
+        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
+        thread.start()
+
+        return streamer
 
 # Custom theme colors based on brand standards
 class ArtemisiaTheme(Base):
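
For reference, a minimal sketch of driving the new wrapper directly, outside Gradio. It assumes a CUDA device and access to the two checkpoints above; the question text and empty history are illustrative placeholders:

wrapper = HfModelWrapper()
# Prior turns go in as {"role": ..., "content": ...} dicts; empty for a fresh chat.
history = []
for chunk in wrapper.generate("Summarize the open high-severity risks.", history):
    print(chunk, end="", flush=True)  # chunks arrive while generate() runs in its background thread
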
@@ -431,11 +483,8 @@ paper_plane_svg = """<svg xmlns="http://www.w3.org/2000/svg" width="20" height="
 <path d="M22 2L15 22L11 13L2 9L22 2Z"/>
 </svg>"""
 
-# Pipeline loading
-#generator = pipeline("text-generation", model="openai-community/gpt2")
-tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
-model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+# Load the model once at startup; all chat requests share this instance.
+wrapper = HfModelWrapper()
 
 # Mock data function for chatbot
 def send_message(message, history):
@@ -444,23 +493,10 @@ def send_message(message, history):
     history.append({"role": "user", "content": message})
     #history.append({"role": "assistant", "content": f"This is a response about: {message}"})
     #return history
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
-    input_ids = tokenizer.encode(message, return_tensors="pt")
-    gen_kwargs = {
-        "inputs": input_ids,
-        "streamer": streamer,
-        "pad_token_id": tokenizer.eos_token_id,
-        "max_length": 8192,
-        "temperature": 0.1,
-        "top_p": 0.8,
-        "repetition_penalty": 1.25,
-    }
+    # Delegate generation to the shared wrapper; it returns a streamer
+    # that yields decoded text chunks while the model runs.
+    response_generator = wrapper.generate(message, history)
     partial = ""
-    thread = Thread(target=model.generate, kwargs=gen_kwargs)
-    thread.start()
-    #for token in generator(message, max_new_tokens=200):
-    for t in streamer:
-        partial += t  #token["generated_text"][len(message):]
+    for t in response_generator:
+        partial += t
         yield history + [{"role": "assistant", "content": partial}]
 
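The UI wiring that consumes send_message is outside this diff; below is a minimal sketch under the message-dict history format used above. The component layout is an assumption for illustration, not the app's actual interface:

import gradio as gr

with gr.Blocks(theme=ArtemisiaTheme()) as demo:
    chatbot = gr.Chatbot(type="messages")  # renders {"role": ..., "content": ...} dicts
    box = gr.Textbox(placeholder="Ask about a risk...")
    # send_message is a generator, so Gradio re-renders the chatbot on every yield
    box.submit(send_message, inputs=[box, chatbot], outputs=chatbot)

demo.launch()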