import os
import math
from typing import Union, Optional
import torch
import logging
#from vllm import LLM, SamplingParams
#from vllm.lora.request import LoRARequest
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, set_seed, BitsAndBytesConfig
import openai
from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
                          Timeout, APIConnectionError, InvalidRequestError)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      stop_after_delay, wait_random_exponential, stop_after_attempt)

logger = logging.getLogger(__name__)

class Summarizer:
    def __init__(self,
                 inference_mode: str,
                 model_id: str,
                 api_key: str,
                 dtype: str = "bfloat16",
                 seed: int = 42,
                 context_size: int = int(1024 * 26),
                 gpu_memory_utilization: float = 0.7,
                 tensor_parallel_size: int = 1
                 ) -> None:
        self.inference_mode = inference_mode
        self.model = None
        self.tokenizer = None
        self.seed = seed
        openai.api_key = api_key
        self.model = model_id

    def get_generation_config(
        self,
        repetition_penalty: float = 1.2,
        do_sample: bool = True,
        temperature: float = 0.1,
        top_p: float = 0.9,
        max_tokens: int = 1024
    ):
        # Collect the sampling settings into a dict so the caller can pass them to the backend in use.
        generation_config = {
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
        }
        return generation_config

    def inference_with_gpt(self, prompt):
        prompt_messages = [{"role": "user", "content": prompt}]
        try:
            response = openai.ChatCompletion.create(model=self.model, messages=prompt_messages, temperature=0.1)
            #finish_reason = response.choices[0].finish_reason
            response = response.choices[0].message.content
        except InvalidRequestError:
            response = ''
        return response
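
# Minimal usage sketch (illustrative assumption, not part of the original file):
# the inference mode, model id, and environment variable below are placeholders.
if __name__ == "__main__":
    summarizer = Summarizer(
        inference_mode="openai",
        model_id="gpt-3.5-turbo",
        api_key=os.environ.get("OPENAI_API_KEY", ""),
    )
    print(summarizer.inference_with_gpt("Summarize the following text: ..."))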