Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -7,6 +7,62 @@ from gradio.themes.utils import colors

 from transformers import pipeline, TextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
+from threading import Thread  # generate() below is driven on a worker thread

+SYSTEM_PROMPT = "You are a compliance assistant. Use the provided risk data to answer user questions. If a single risk object is given, provide a direct answer. If a list of risks is provided, summarize, compare, or analyze the collection as needed. Always base your response on the data provided."
+
+class HfModelWrapper:
+    def __init__(
+        self,
+        model_path="casperhansen/llama-3.3-70b-instruct-awq",
+        sys_prompt=SYSTEM_PROMPT,
+        adapter_path="artemisiaai/fine-tuned-adapter",
+    ):
+
+        # the AWQ checkpoint already carries its quantization config
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path, device_map="auto"
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.sys_prompt = sys_prompt
+        self.adapter_path = adapter_path
+        self.model.load_adapter(self.adapter_path)
+        self.model.enable_adapters()
+
+    def build_prompt(self, user_msg, history):
+
+        messages = []
+        messages.append({"role": "system", "content": self.sys_prompt})
+        messages += history
+        messages.append({"role": "user", "content": user_msg})
+
+        prompt = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,  # cue the model to answer as the assistant
+        )
+        return prompt
+
+    def generate(self, user_input, history):
+        input_text = self.build_prompt(user_input, history)
+        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to("cuda")
+
+        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
+
+        gen_kwargs = {
+            "inputs": input_ids,
+            "streamer": streamer,
+            "pad_token_id": self.tokenizer.eos_token_id,
+            "max_length": 8192,
+            "do_sample": True,  # temperature/top_p have no effect without sampling
+            "temperature": 0.1,
+            "top_p": 0.8,
+            "repetition_penalty": 1.25,
+        }
+
+        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
+        thread.start()
+
+        return streamer

 # Custom theme colors based on brand standards
 class ArtemisiaTheme(Base):
@@ -431,11 +487,8 @@ paper_plane_svg = """<svg xmlns="http://www.w3.org/2000/svg" width="20" height="
     <path d="M22 2L15 22L11 13L2 9L22 2Z"/>
 </svg>"""


-# Pipeline loading
-#generator = pipeline("text-generation", model="openai-community/gpt2")
-tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
-model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+wrapper = HfModelWrapper()

 # Mock data function for chatbot
 def send_message(message, history):
@@ -444,23 +497,10 @@ def send_message(message, history):
     history.append({"role": "user", "content": message})
     #history.append({"role": "assistant", "content": f"This is a response about: {message}"})
     #return history
-
-    input_ids = tokenizer.encode(message, return_tensors="pt")
-    gen_kwargs = {
-        "inputs": input_ids,
-        "streamer": streamer,
-        "pad_token_id": tokenizer.eos_token_id,
-        "max_length": 8192,
-        "temperature": 0.1,
-        "top_p": 0.8,
-        "repetition_penalty": 1.25,
-    }
+    response_generator = wrapper.generate(message, history)
     partial = ""
-
-
-    #for token in generator(message, max_new_tokens=200):
-    for t in streamer:
-        partial += t#token["generated_text"][len(message):]
+    for t in response_generator:
+        partial += t
     yield history + [{"role": "assistant", "content": partial}]


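The piece of this change worth isolating is the streaming handoff: TextIteratorStreamer is a blocking iterator that model.generate feeds as it decodes, so generation has to run on a worker thread while the caller drains the stream. Below is a minimal, self-contained sketch of that pattern, run on the small openai-community/gpt2 checkpoint this commit retires; the prompt and generation settings are illustrative, not the Space's.

# Minimal sketch of the streaming pattern adopted above, standalone on the
# small gpt2 checkpoint this commit removes; prompt and settings are
# illustrative only.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

input_ids = tokenizer.encode("Compliance risk means", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

# generate() blocks until decoding finishes, so it runs on a worker thread
# while this thread consumes decoded text chunks from the streamer.
thread = Thread(
    target=model.generate,
    kwargs={
        "inputs": input_ids,
        "streamer": streamer,
        "max_new_tokens": 50,
        "pad_token_id": tokenizer.eos_token_id,
    },
)
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
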
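For context on how a generator like send_message typically plugs into the UI: Gradio re-renders a messages-format Chatbot on every yield, which is what turns the growing partial string into token-by-token streaming. The Space's Blocks layout is not part of this diff, so every component name in this sketch is an assumption.

# Hypothetical wiring for send_message; the Space's actual layout is not
# shown in this diff, so the component names here are assumptions.
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")  # accepts the role/content dicts yielded above
    msg = gr.Textbox(placeholder="Ask about a risk...")

    # send_message is a generator, so Gradio updates the Chatbot on each
    # yield, producing the streaming effect.
    msg.submit(send_message, inputs=[msg, chatbot], outputs=chatbot)

demo.launch()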