#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional

from huggingface_hub import InferenceClient

from ..pipelines.base import Pipeline


class MessageRole(str, Enum):
    """The set of roles a chat message may carry in this module."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL_CALL = "tool-call"
    TOOL_RESPONSE = "tool-response"

    @classmethod
    def roles(cls):
        """Return the string values of all members, in definition order."""
        return [r.value for r in cls]


def get_clean_message_list(
    message_list: List[Dict[str, str]], role_conversions: Optional[Dict[str, str]] = None
) -> List[Dict[str, str]]:
    """
    Validate chat messages, optionally remap their roles, and merge consecutive same-role messages.

    Subsequent messages with the same role (after any role conversion) are concatenated into a
    single message, their contents joined by a `=======` separator line.

    Args:
        message_list (`List[Dict[str, str]]`):
            List of chat messages. Each message must have exactly the keys `role` and `content`.
            The input list is deep-copied and never modified.
        role_conversions (`Dict[str, str]`, *optional*):
            Mapping from an original role to the role it should be sent as. Roles absent from the
            mapping are kept unchanged. Defaults to no conversion.

    Returns:
        `List[Dict[str, str]]`: The cleaned, merged list of messages.

    Raises:
        ValueError: If a message has keys other than exactly `role` and `content`, or if its role is
            not one of `MessageRole.roles()`.
    """
    if role_conversions is None:
        role_conversions = {}
    supported_roles = MessageRole.roles()  # computed once rather than once per message

    final_message_list = []
    message_list = deepcopy(message_list)  # Avoid modifying the original list
    for message in message_list:
        if set(message.keys()) != {"role", "content"}:
            raise ValueError("Message should contain only 'role' and 'content' keys!")

        role = message["role"]
        if role not in supported_roles:
            raise ValueError(f"Incorrect role {role}, only {supported_roles} are supported for now.")

        if role in role_conversions:
            message["role"] = role_conversions[role]

        # Merge into the previous message when the (converted) roles match.
        if final_message_list and message["role"] == final_message_list[-1]["role"]:
            final_message_list[-1]["content"] += "\n=======\n" + message["content"]
        else:
            final_message_list.append(message)
    return final_message_list


# Role remapping applied by both engines before calling the model: tool responses are sent as user
# messages. NOTE(review): presumably because Llama-style chat templates have no tool-response role — confirm.
llama_role_conversions = {
    MessageRole.TOOL_RESPONSE: MessageRole.USER,
}


def _remove_stop_sequences(response: str, stop_sequences: Optional[List[str]]) -> str:
    """Strip each stop sequence, checked in order, from the end of `response` when it ends with it."""
    if not stop_sequences:
        return response
    for stop_seq in stop_sequences:
        # Skip empty strings: `endswith("")` is always True and `response[:-0]` would erase the text.
        if stop_seq and response.endswith(stop_seq):
            response = response[: -len(stop_seq)]
    return response


class HfApiEngine:
    """This engine leverages Hugging Face's Inference API service, either serverless or with a dedicated endpoint."""

    def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"):
        self.model = model
        self.client = InferenceClient(self.model, timeout=120)

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
    ) -> str:
        """
        Generate a chat completion for `messages` through the Inference API.

        Args:
            messages: Chat messages, cleaned with `get_clean_message_list` and `llama_role_conversions`.
            stop_sequences: Strings at which generation stops; a trailing occurrence is stripped from the
                returned text. Defaults to no stop sequences.
            grammar: If not `None`, forwarded to `chat_completion` as `response_format`. NOTE(review): the
                module's `DEFAULT_*_REGEX_GRAMMAR` dicts look like the intended values — confirm.

        Returns:
            The content of the first returned choice, capped at 1500 new tokens (`max_tokens`).
        """
        if stop_sequences is None:
            stop_sequences = []

        # Get clean message list
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)

        # Get LLM output
        if grammar is not None:
            response = self.client.chat_completion(
                messages, stop=stop_sequences, max_tokens=1500, response_format=grammar
            )
        else:
            response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)

        response = response.choices[0].message.content

        # Remove stop sequences from LLM output
        return _remove_stop_sequences(response, stop_sequences)


class TransformersEngine:
    """This engine uses a pre-initialized local text-generation pipeline."""

    def __init__(self, pipeline: Pipeline):
        self.pipeline = pipeline

    def __call__(
        self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
    ) -> str:
        """
        Generate a chat completion for `messages` with the local pipeline.

        Args:
            messages: Chat messages, cleaned with `get_clean_message_list` and `llama_role_conversions`.
            stop_sequences: Passed to the pipeline as `stop_strings`; a trailing occurrence is stripped from
                the returned text.
            grammar: Accepted for interface parity with `HfApiEngine`; not used by this engine.

        Returns:
            The `content` of the last message in the pipeline's `generated_text`.
        """
        # Get clean message list
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)

        # Get LLM output. `max_new_tokens` (not `max_length`, which also counts the prompt) so that long
        # agent prompts still leave room for generation, matching HfApiEngine's `max_tokens=1500`.
        output = self.pipeline(
            messages,
            stop_strings=stop_sequences,
            max_new_tokens=1500,
            tokenizer=self.pipeline.tokenizer,  # `stop_strings` needs the tokenizer passed explicitly
        )
        response = output[0]["generated_text"][-1]["content"]

        # Remove stop sequences from LLM output
        return _remove_stop_sequences(response, stop_sequences)


DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
    "type": "regex",
    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n',
}

DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
    "type": "regex",
    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```",
}