"""Handles loading and running of models."""

import json
import math
import os
import re
import warnings
from time import sleep, time

import spaces
from dotenv import load_dotenv

from logger import logger

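# Silence noisy library warnings and vLLM logging before any heavy imports happen.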
warnings.filterwarnings("ignore")
os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"


load_dotenv()
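# Labels the guardian model emits for each evaluated risk ("Yes" = risky, "No" = safe) and the
# number of top log-probabilities inspected when estimating the probability of risk.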
safe_token = "No"
risky_token = "Yes"
nlogprobs = 20

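# INFERENCE_ENGINE selects how results are produced: "TORCH" runs the model locally via
# transformers, while "MOCK" returns canned results (useful for UI work without a GPU).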
inference_engine = os.getenv("INFERENCE_ENGINE", "TORCH")
logger.debug(f"Inference engine is: {inference_engine}")

if inference_engine == "TORCH":
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # backend_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    backend_device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.debug(f"Backend device is: {backend_device}")

    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.2-3b-a800m")
    logger.debug(f"model_path is {model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)

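    # Note: the model currently runs on CPU regardless of the backend_device detected above,
    # which is only logged for diagnostics.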
    device = torch.device("cpu")

    model = AutoModelForCausalLM.from_pretrained(model_path)
    model = model.to(device).eval()


def get_probabilities(logprobs):
    """Aggregate the probability mass of the safe and risky tokens across the generated positions."""
    safe_token_prob = 1e-50
    unsafe_token_prob = 1e-50
    for gen_token_i in logprobs:
        for logprob, index in zip(gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]):
            decoded_token = tokenizer.convert_ids_to_tokens(index)
            if decoded_token.strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(logprob)
            if decoded_token.strip().lower() == risky_token.lower():
                unsafe_token_prob += math.exp(logprob)

    probabilities = torch.softmax(torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0)

    return probabilities


def parse_output(output_ids, input_len):
    """Extract the label, confidence level, probability of risk, and certainty from the model output."""
    label, prob_of_risk = None, None
    if nlogprobs > 0:
        # Top-k log-probabilities for every generated position except the final one.
        list_index_logprobs_i = [
            torch.topk(token_i, k=nlogprobs, largest=True, sorted=True) for token_i in list(output_ids.scores)[:-1]
        ]
        if list_index_logprobs_i:
            prob = get_probabilities(list_index_logprobs_i)
            prob_of_risk = round(prob[1].item(), 3)

    generated_text = tokenizer.decode(output_ids.sequences[:, input_len:][0], skip_special_tokens=True).strip()
    res_match = re.search(r"^\w+", generated_text, re.MULTILINE)
    res = res_match.group(0).strip() if res_match else ""
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    confidence_match = re.search(r"<confidence> (.*?) </confidence>", generated_text)
    confidence_level = confidence_match.group(1).strip() if confidence_match else None
    certainty = None if prob_of_risk is None else (prob_of_risk if prob_of_risk > 0.5 else 1 - prob_of_risk)

    return label, confidence_level, prob_of_risk, certainty


@spaces.GPU
def get_prompt(messages, criteria_name, criteria_description=None):
    """Build the Granite Guardian chat-template prompt for the given messages and risk criteria."""
    logger.debug("Creating prompt for the model.")
    logger.debug(f"Messages are: {json.dumps(messages, indent=2)}")

    if criteria_name == "general_harm":
        criteria_name = "harm"
    elif criteria_name == "function_calling_hallucination":
        criteria_name = "function_call"
    logger.debug("Criteria name was changed too: " + criteria_name)
    guardian_config = {"risk_name": criteria_name}
    if criteria_description is not None:
        guardian_config['risk_definition'] = criteria_description
    logger.debug(f"guardian_config is: {guardian_config}")
    prompt = tokenizer.apply_chat_template(
        messages,
        guardian_config=guardian_config,
        tokenize=False,
        add_generation_prompt=True,
    )
    logger.debug(f"Prompt is:\n{prompt}")
    return prompt


@spaces.GPU
def get_guardian_response(messages, criteria_name, criteria_description=None):
    """Evaluate the messages against the given criteria and return the risk label and certainty."""
    start = time()
    if criteria_name == "general_harm":
        criteria_name = "harm"
    elif criteria_name == "function_calling_hallucination":
        criteria_name = "function_call"
    logger.debug(f"Messages are: {json.dumps(messages, indent=2)}")
    if inference_engine == "MOCK":
        logger.debug("Returning mocked model result.")
        sleep(1)
        label, confidence_level, prob_of_risk, certainty = "Yes", 'High', 0.97,  0.97

    elif inference_engine == "TORCH":
        guardian_config = {"risk_name": criteria_name}
        if criteria_description is not None:
            guardian_config['risk_definition'] = criteria_description
        logger.debug(f"guardian_config is: {guardian_config}")

        input_ids = tokenizer.apply_chat_template(
            messages, guardian_config=guardian_config, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        input_len = input_ids.shape[1]

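        # Greedy decode a short response; per-token scores are kept so parse_output can recover
        # the probability of the risky token. Note that nlogprobs doubles as the generation budget here.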
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                do_sample=False,
                max_new_tokens=nlogprobs,
                return_dict_in_generate=True,
                output_scores=True,
            )

            label, confidence_level, prob_of_risk, certainty = parse_output(output_ids, input_len)

    else:
        raise ValueError("Environment variable 'INFERENCE_ENGINE' must be one of [MOCK, TORCH]")

    logger.debug(f"label={label}, confidence_level={confidence_level}, prob_of_risk={prob_of_risk}, certainty={certainty}")

    end = time()
    logger.debug(f"The evaluation took {end - start:.2f} secs")

    return {"label": label, "certainty": certainty}