#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional

from huggingface_hub import InferenceClient

from ..pipelines.base import Pipeline


class MessageRole(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL_CALL = "tool-call"
    TOOL_RESPONSE = "tool-response"

    @classmethod
    def roles(cls):
        return [r.value for r in cls]


def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
    """
    Creates a clean version of the message list: roles are remapped according to `role_conversions`,
    and subsequent messages with the same role are concatenated into a single message.

    Args:
        message_list (`List[Dict[str, str]]`): List of chat messages, each with a 'role' and a 'content' key.
        role_conversions (`Dict[str, str]`, *optional*): Mapping applied to message roles before merging.
    """
    final_message_list = []
    message_list = deepcopy(message_list)  # Avoid modifying the original list
    for message in message_list:
        if not set(message.keys()) == {"role", "content"}:
            raise ValueError("Message should contain only 'role' and 'content' keys!")

        role = message["role"]
        if role not in MessageRole.roles():
            raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")

        if role in role_conversions:
            message["role"] = role_conversions[role]

        if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
            final_message_list[-1]["content"] += "\n=======\n" + message["content"]
        else:
            final_message_list.append(message)
    return final_message_list
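
# Illustrative sketch of the behavior above (not executed at import time): roles listed in the
# conversion mapping are remapped first, then consecutive messages sharing a role are merged.
#
#     messages = [
#         {"role": "user", "content": "Hello"},
#         {"role": "tool-response", "content": "42"},
#     ]
#     get_clean_message_list(messages, role_conversions={MessageRole.TOOL_RESPONSE: MessageRole.USER})
#     # -> [{"role": "user", "content": "Hello\n=======\n42"}]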


llama_role_conversions = {
    MessageRole.TOOL_RESPONSE: MessageRole.USER,
}


class HfApiEngine:
    """This engine leverages Hugging Face's Inference API service, either serverless or with a dedicated endpoint."""

    def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"):
        self.model = model
        self.client = InferenceClient(self.model, timeout=120)

    def __call__(
        self, messages: List[Dict[str, str]], stop_sequences: List[str] = [], grammar: Optional[str] = None
    ) -> str:
        # Get clean message list
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)

        # Get LLM output
        if grammar is not None:
            response = self.client.chat_completion(
                messages, stop=stop_sequences, max_tokens=1500, response_format=grammar
            )
        else:
            response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
        response = response.choices[0].message.content

        # Remove stop sequences from LLM output
        for stop_seq in stop_sequences:
            if response[-len(stop_seq) :] == stop_seq:
                response = response[: -len(stop_seq)]
        return response
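
# Usage sketch (assumes access to the Hugging Face Inference API, typically via a configured token;
# the prompt and stop sequence below are purely illustrative):
#
#     engine = HfApiEngine()
#     answer = engine(
#         [{"role": "user", "content": "What is the capital of France?"}],
#         stop_sequences=["<end_action>"],
#     )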


class TransformersEngine:
    """This engine uses a pre-initialized local text-generation pipeline."""

    def __init__(self, pipeline: Pipeline):
        self.pipeline = pipeline

    def __call__(
        self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
    ) -> str:
        # Get clean message list
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)

        # Get LLM output
        output = self.pipeline(
            messages,
            stop_strings=stop_sequences,
            max_length=1500,
            tokenizer=self.pipeline.tokenizer,
        )
        response = output[0]["generated_text"][-1]["content"]

        # Remove stop sequences from LLM output
        if stop_sequences is not None:
            for stop_seq in stop_sequences:
                if response[-len(stop_seq) :] == stop_seq:
                    response = response[: -len(stop_seq)]
        return response
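
# Usage sketch (the checkpoint name is only an example; any chat-capable text-generation pipeline works):
#
#     from transformers import pipeline
#
#     text_gen = pipeline("text-generation", model="HuggingFaceTB/SmolLM-135M-Instruct")
#     engine = TransformersEngine(text_gen)
#     answer = engine(
#         [{"role": "user", "content": "Summarize the Apache-2.0 license in one sentence."}],
#         stop_sequences=["<end_action>"],
#     )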


# Default grammars used to constrain generation to the agents' expected output formats:
# a "Thought:" block followed by either a JSON action blob or a fenced code block.
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
    "type": "regex",
    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
}

DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
    "type": "regex",
    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
}