"""Handles loading and running of models.""" | |
from calendar import c | |
import json | |
import math | |
import os | |
import re | |
import warnings | |
from time import sleep, time | |
import spaces | |
from dotenv import load_dotenv | |
from logger import logger | |
warnings.filterwarnings("ignore") | |
os.environ["VLLM_LOGGING_LEVEL"] = "ERROR" | |
load_dotenv() | |
safe_token = "No"
risky_token = "Yes"
nlogprobs = 20

inference_engine = os.getenv("INFERENCE_ENGINE", "TORCH")
logger.debug(f"Inference engine is: {inference_engine}")
if inference_engine == "TORCH":
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # backend_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    backend_device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.debug(f"Backend device is: {backend_device}")

    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.2-3b-a800m")
    logger.debug(f"model_path is {model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # The model is loaded and run on CPU regardless of the detected backend device.
    device = torch.device("cpu")
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model = model.to(device).eval()
def get_probabilities(logprobs):
    """Aggregate the probability mass assigned to the safe/risky tokens across generated steps."""
    safe_token_prob = 1e-50
    unsafe_token_prob = 1e-50
    for gen_token_i in logprobs:
        for logprob, index in zip(gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]):
            decoded_token = tokenizer.convert_ids_to_tokens(index)
            if decoded_token.strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(logprob)
            if decoded_token.strip().lower() == risky_token.lower():
                unsafe_token_prob += math.exp(logprob)

    probabilities = torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
    )
    return probabilities
def parse_output(output_ids, input_len):
    """Parse the generation output into (label, confidence_level, prob_of_risk, certainty)."""
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        # Top-k log-probabilities for every generated token except the final one.
        list_index_logprobs_i = [
            torch.topk(token_i, k=nlogprobs, largest=True, sorted=True) for token_i in list(output_ids.scores)[:-1]
        ]
        if list_index_logprobs_i is not None:
            prob = get_probabilities(list_index_logprobs_i)
            prob_of_risk = round(prob[1].item(), 3)

    generated_text = tokenizer.decode(output_ids.sequences[:, input_len:][0], skip_special_tokens=True).strip()
    res = re.search(r"^\w+", generated_text, re.MULTILINE).group(0).strip()
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    confidence_level = re.search(r"<confidence> (.*?) </confidence>", generated_text).group(1).strip()
    certainty = prob_of_risk if prob_of_risk > 0.5 else 1 - prob_of_risk
    return label, confidence_level, prob_of_risk, certainty
def get_prompt(messages, criteria_name, criteria_description=None):
    """Build the Granite Guardian chat prompt for the given messages and risk criteria."""
    logger.debug("Creating prompt for the model.")
    logger.debug(f"Messages are: {json.dumps(messages, indent=2)}")

    # Map UI-facing criteria names onto the risk names expected by the guardian chat template.
    if criteria_name == "general_harm":
        criteria_name = "harm"
    elif criteria_name == "function_calling_hallucination":
        criteria_name = "function_call"
    logger.debug("Criteria name was changed to: " + criteria_name)

    guardian_config = {"risk_name": criteria_name}
    if criteria_description is not None:
        guardian_config["risk_definition"] = criteria_description
    logger.debug(f"guardian_config is: {guardian_config}")

    prompt = tokenizer.apply_chat_template(
        messages,
        guardian_config=guardian_config,
        tokenize=False,
        add_generation_prompt=True,
    )
    logger.debug(f"Prompt is:\n{prompt}")
    return prompt
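
# Illustrative call (the message content below is an assumption, not part of the app):
#   get_prompt([{"role": "user", "content": "Is this request safe?"}], "general_harm")
# returns the fully rendered guardian prompt string without running generation; it is the
# prompt-preview counterpart of get_guardian_response() below.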
def get_guardian_response(messages, criteria_name, criteria_description=None):
    """Run the guardian model (or a mock) on the messages and return the risk label and certainty."""
    start = time()

    # Map UI-facing criteria names onto the risk names expected by the guardian chat template.
    if criteria_name == "general_harm":
        criteria_name = "harm"
    elif criteria_name == "function_calling_hallucination":
        criteria_name = "function_call"

    logger.debug(f"Messages are: {json.dumps(messages, indent=2)}")

    if inference_engine == "MOCK":
        logger.debug("Returning mocked model result.")
        sleep(1)
        label, confidence_level, prob_of_risk, certainty = "Yes", "High", 0.97, 0.97
    elif inference_engine == "TORCH":
        guardian_config = {"risk_name": criteria_name}
        if criteria_description is not None:
            guardian_config["risk_definition"] = criteria_description
        logger.debug(f"guardian_config is: {guardian_config}")

        input_ids = tokenizer.apply_chat_template(
            messages, guardian_config=guardian_config, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        input_len = input_ids.shape[1]

        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                do_sample=False,
                max_new_tokens=nlogprobs,
                return_dict_in_generate=True,
                output_scores=True,
            )

        label, confidence_level, prob_of_risk, certainty = parse_output(output_ids, input_len)
    else:
        raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [MOCK, TORCH]")

    logger.debug(f"label={label}, confidence_level={confidence_level}, prob_of_risk={prob_of_risk}, certainty={certainty}")

    end = time()
    total = end - start
    logger.debug(f"The evaluation took {total} secs")

    return {"label": label, "certainty": certainty}