# Huggingface-Space-Commander / model_logic.py
# Committed by broadfield-dev in 6b5f0c3 ("Create model_logic.py", verified).
import json
import logging
import os
from collections.abc import Iterator

import requests
# Process-wide logging setup: INFO level, each record prefixed with its
# timestamp, logger name and level.
# NOTE(review): basicConfig() in an importable module configures the root
# logger for the whole process; consider moving it to the app entry point.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# Name of the environment variable that holds each provider's API key.
# Keys are upper-case provider names; lookups in this module use provider.upper().
API_KEYS = dict(
    HUGGINGFACE="HF_TOKEN",
    GROQ="GROQ_API_KEY",
    OPENROUTER="OPENROUTER_API_KEY",
    TOGETHERAI="TOGETHERAI_API_KEY",
    COHERE="COHERE_API_KEY",
    XAI="XAI_API_KEY",
    OPENAI="OPENAI_API_KEY",
    GOOGLE="GOOGLE_API_KEY",
)
# Endpoint per provider, keyed by upper-case provider name.
# Most entries are complete chat endpoints. GOOGLE is a prefix: the model ID
# and the ":streamGenerateContent" method are appended to it by the caller.
# HUGGINGFACE presumably takes a model ID appended as well (it is not used
# for requests in this module).
API_URLS = dict(
    HUGGINGFACE="https://api-inference.huggingface.co/models/",
    GROQ="https://api.groq.com/openai/v1/chat/completions",
    OPENROUTER="https://openrouter.ai/api/v1/chat/completions",
    TOGETHERAI="https://api.together.ai/v1/chat/completions",
    COHERE="https://api.cohere.ai/v1/chat",
    XAI="https://api.x.ai/v1/chat/completions",
    OPENAI="https://api.openai.com/v1/chat/completions",
    GOOGLE="https://generativelanguage.googleapis.com/v1beta/models/",
)
# Selectable models, keyed by lower-case provider name. Each provider entry holds:
#   "default": the model ID to preselect, and
#   "models":  UI display name -> provider-specific model ID.
# NOTE(review): provider model catalogues change often and some of these IDs
# may have been retired upstream; verify against each provider's model list.
MODELS_BY_PROVIDER = {
    "groq": {
        "default": "llama3-8b-8192",
        "models": {
            "Llama 3 8B (Groq)": "llama3-8b-8192",
            "Llama 3 70B (Groq)": "llama3-70b-8192",
            "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
            "Gemma 7B (Groq)": "gemma-7b-it",
        },
    },
    "openrouter": {
        "default": "nousresearch/llama-3-8b-instruct",
        "models": {
            "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
            "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
            "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
            "Mixtral 8x7B Instruct v0.1 (OpenRouter)": "mistralai/mixtral-8x7b-instruct",
            "Llama 2 70B Chat (OpenRouter)": "meta-llama/llama-2-70b-chat",
            "Neural Chat 7B v3.1 (OpenRouter)": "intel/neural-chat-7b-v3-1",
            "Goliath 120B (OpenRouter)": "twob/goliath-v2-120b",
        },
    },
    "togetherai": {
        "default": "meta-llama/Llama-3-8b-chat-hf",
        "models": {
            "Llama 3 8B Chat (TogetherAI)": "meta-llama/Llama-3-8b-chat-hf",
            "Llama 3 70B Chat (TogetherAI)": "meta-llama/Llama-3-70b-chat-hf",
            "Mixtral 8x7B Instruct (TogetherAI)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "Gemma 7B Instruct (TogetherAI)": "google/gemma-7b-it",
            "RedPajama INCITE Chat 3B (TogetherAI)": "togethercomputer/RedPajama-INCITE-Chat-3B-v1",
        },
    },
    "google": {
        "default": "gemini-1.5-flash-latest",
        "models": {
            "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
            "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
        },
    },
    "cohere": {
        "default": "command-light",
        "models": {
            "Command R (Cohere)": "command-r",
            "Command R+ (Cohere)": "command-r-plus",
            "Command Light (Cohere)": "command-light",
            "Command (Cohere)": "command",
        },
    },
    "huggingface": {
        "default": "HuggingFaceH4/zephyr-7b-beta",
        "models": {
            "Zephyr 7B Beta (H4/HF Inf.)": "HuggingFaceH4/zephyr-7b-beta",
            "Mistral 7B Instruct v0.2 (HF Inf.)": "mistralai/Mistral-7B-Instruct-v0.2",
            "Llama 2 13B Chat (Meta/HF Inf.)": "meta-llama/Llama-2-13b-chat-hf",
            "OpenAssistant/oasst-sft-4-pythia-12b (HF Inf.)": "OpenAssistant/oasst-sft-4-pythia-12b",
        },
    },
    "openai": {
        "default": "gpt-3.5-turbo",
        "models": {
            "GPT-4o (OpenAI)": "gpt-4o",
            "GPT-4o mini (OpenAI)": "gpt-4o-mini",
            "GPT-4 Turbo (OpenAI)": "gpt-4-turbo",
            "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
        },
    },
    "xai": {
        "default": "grok-1",
        "models": {
            "Grok-1 (xAI)": "grok-1",
        },
    },
}
def _get_api_key(provider: str, ui_api_key_override: str | None = None) -> str | None:
    """Resolve the API key to use for ``provider``.

    Lookup order:
      1. ``ui_api_key_override`` (e.g. a key typed into the UI), if it is
         non-blank after stripping whitespace;
      2. the environment variable mapped to the provider in ``API_KEYS``.

    Args:
        provider: Provider name in any case (e.g. ``"groq"`` or ``"GROQ"``).
        ui_api_key_override: Optional explicit key that takes precedence over
            the environment.

    Returns:
        The stripped key, or ``None`` (after logging a warning) when no
        non-blank key is available.
    """
    # Strip *before* testing, so a whitespace-only override falls through to
    # the environment variable instead of masking it with an empty string.
    override = (ui_api_key_override or "").strip()
    if override:
        return override

    # API_KEYS already maps HUGGINGFACE -> HF_TOKEN, so no provider needs a
    # separate fallback lookup here.
    env_var_name = API_KEYS.get(provider.upper())
    if env_var_name:
        env_key = (os.getenv(env_var_name) or "").strip()
        if env_key:
            return env_key

    logger.warning(f"API Key not found for provider '{provider}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
    return None
def get_available_providers() -> list[str]:
    """Return every configured provider name (lower-case), alphabetically sorted."""
    # Iterating a dict yields its keys, so sorted() can consume it directly.
    return sorted(MODELS_BY_PROVIDER)
def get_models_for_provider(provider: str) -> list[str]:
    """Return the sorted UI display names of ``provider``'s models.

    The provider name is matched case-insensitively. An unknown provider
    yields an empty list.
    """
    provider_config = MODELS_BY_PROVIDER.get(provider.lower(), {})
    display_names = provider_config.get("models", {})
    return sorted(display_names)
def get_default_model_for_provider(provider: str) -> str | None:
    """Return the display name of ``provider``'s default model.

    If the configured ``"default"`` model ID matches one of the provider's
    models, that model's display name is returned. Otherwise the
    alphabetically first display name is returned. Returns ``None`` when the
    provider is unknown or has no models.
    """
    provider_config = MODELS_BY_PROVIDER.get(provider.lower(), {})
    display_to_id = provider_config.get("models", {})
    wanted_id = provider_config.get("default")

    if wanted_id:
        # First display name whose model ID equals the configured default.
        hit = next((name for name, mid in display_to_id.items() if mid == wanted_id), None)
        if hit is not None:
            return hit

    # min() over the keys equals sorted(keys)[0] for strings.
    return min(display_to_id) if display_to_id else None
def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
    """Map a UI display name to the provider-specific model ID.

    The provider name is matched case-insensitively. Returns ``None`` when
    either the provider or the display name is unknown.
    """
    return (
        MODELS_BY_PROVIDER
        .get(provider.lower(), {})
        .get("models", {})
        .get(display_name)
    )
# Providers that expose an OpenAI-compatible /chat/completions SSE stream.
_OPENAI_COMPATIBLE_PROVIDERS = ("groq", "openrouter", "togetherai", "openai", "xai")
_STREAM_TIMEOUT_SECONDS = 180
_STREAM_CHUNK_SIZE = 8192


def _iter_decoded_lines(response: "requests.Response") -> Iterator[str]:
    """Yield each line of a streamed HTTP body, decoded as UTF-8 and stripped.

    ``iter_lines`` reassembles lines that are split across network chunks. It
    also yields a final line that has no trailing newline, so no separate
    "leftover buffer" pass is needed. Stripping removes the ``\\r`` of
    CRLF-terminated SSE lines. Undecodable bytes are dropped.
    """
    for raw_line in response.iter_lines(chunk_size=_STREAM_CHUNK_SIZE):
        yield raw_line.decode("utf-8", errors="ignore").strip()


def _iter_sse_data(response: "requests.Response") -> Iterator[str]:
    """Yield the payload of every ``data:`` line of a Server-Sent Events stream.

    Comment lines (``:`` prefix, e.g. keep-alives), ``event:``/``id:`` lines
    and blank separators are skipped.
    """
    for line in _iter_decoded_lines(response):
        if line.startswith("data:"):
            yield line[len("data:"):].strip()


def _iter_openai_compatible_text(response: "requests.Response") -> Iterator[str]:
    """Yield ``choices[0].delta.content`` fragments from an OpenAI-style SSE stream."""
    for data in _iter_sse_data(response):
        if data == "[DONE]":
            return  # explicit end-of-stream sentinel
        if not data:
            continue
        try:
            event = json.loads(data)
        except json.JSONDecodeError:
            logger.warning("Failed to decode JSON from stream line: %s", data)
            continue
        choices = event.get("choices") if isinstance(event, dict) else None
        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
            continue
        delta = choices[0].get("delta")
        if isinstance(delta, dict) and delta.get("content"):
            yield delta["content"]


def _build_google_payload(messages: list[dict]) -> dict:
    """Translate OpenAI-style messages into a Gemini ``generateContent`` body.

    System messages become ``system_instruction`` (the last one wins).
    ``assistant`` turns are renamed to Gemini's ``model`` role. Other roles
    pass through unchanged.
    """
    system_instruction = None
    contents = []
    for msg in messages:
        if msg["role"] == "system":
            system_instruction = msg["content"]
        else:
            role = "model" if msg["role"] == "assistant" else msg["role"]
            contents.append({"role": role, "parts": [{"text": msg["content"]}]})
    payload = {
        "contents": contents,
        "safetySettings": [
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ],
        "generationConfig": {
            "temperature": 0.7,
        },
    }
    if system_instruction:
        payload["system_instruction"] = {"parts": [{"text": system_instruction}]}
    return payload


def _iter_google_text(response: "requests.Response") -> Iterator[str]:
    """Yield the concatenated ``parts[].text`` of each Gemini SSE chunk's first candidate."""
    for data in _iter_sse_data(response):
        if not data:
            continue
        try:
            parsed = json.loads(data)
        except json.JSONDecodeError:
            logger.warning("Failed to decode JSON from Google stream chunk: %s", data)
            continue
        for event in parsed if isinstance(parsed, list) else [parsed]:
            if not isinstance(event, dict):
                continue
            candidates = event.get("candidates")
            if not isinstance(candidates, list) or not candidates or not isinstance(candidates[0], dict):
                continue
            content = candidates[0].get("content")
            parts = content.get("parts") if isinstance(content, dict) else None
            if not parts:
                continue
            text = "".join(part.get("text", "") for part in parts if isinstance(part, dict))
            if text:
                yield text


def _build_cohere_payload(model_id: str, messages: list[dict]) -> dict | None:
    """Translate OpenAI-style messages into a Cohere v1 ``/chat`` body.

    The last user/assistant turn becomes ``message``, earlier turns become
    ``chat_history``, and the last system message becomes ``preamble``.
    Returns ``None`` when there is no non-empty turn to send.
    """
    system_prompt = None
    turns = []
    for msg in messages:
        if msg["role"] == "system":
            system_prompt = msg["content"]
        elif msg["role"] in ("user", "assistant"):
            turns.append(msg)
    if not turns or not turns[-1]["content"]:
        return None
    payload = {
        "model": model_id,
        "message": turns[-1]["content"],
        "stream": True,
        "temperature": 0.7,
    }
    # v1 chat_history roles are the documented upper-case enum values.
    history = [
        {"role": "CHATBOT" if m["role"] == "assistant" else "USER", "message": m["content"]}
        for m in turns[:-1]
    ]
    if history:
        payload["chat_history"] = history
    if system_prompt:
        payload["preamble"] = system_prompt
    return payload


def _iter_cohere_text(response: "requests.Response") -> Iterator[str]:
    """Yield ``text-generation`` fragments from a Cohere v1 ``/chat`` stream.

    The v1 endpoint streams newline-delimited JSON objects that carry an
    ``event_type`` field. SSE framing (``event:`` + ``data:`` lines) is also
    accepted; there, the ``event:`` line supplies the type when the JSON body
    lacks ``event_type``. Reading stops at ``stream-end``.
    """
    sse_event_type = None
    for line in _iter_decoded_lines(response):
        if not line:
            sse_event_type = None  # a blank line terminates an SSE event block
            continue
        if line.startswith(":"):
            continue  # SSE comment / keep-alive
        if line.startswith("event:"):
            sse_event_type = line[len("event:"):].strip()
            continue
        if line.startswith("data:"):
            line = line[len("data:"):].strip()
        try:
            event = json.loads(line)
        except json.JSONDecodeError:
            logger.warning("Cohere: Failed to decode event data JSON: %s", line)
            continue
        if not isinstance(event, dict):
            continue
        event_type = event.get("event_type") or sse_event_type
        if event_type == "text-generation":
            text = event.get("text")
            if text:
                yield text
        elif event_type == "stream-end":
            return


def generate_stream(provider: str, model_display_name: str, api_key_override: str | None, messages: list[dict]) -> Iterator[str]:
    """Stream a chat completion from ``provider`` as an iterator of text fragments.

    Args:
        provider: Provider name, case-insensitive (a key of ``MODELS_BY_PROVIDER``).
        model_display_name: UI display name, resolved to a model ID through
            ``get_model_id_from_display_name``.
        api_key_override: Optional key that takes precedence over the
            provider's environment variable.
        messages: Chat messages as dicts with ``"role"`` and ``"content"``
            (presumably OpenAI-style system/user/assistant roles; verify
            against callers).

    Yields:
        Text fragments as they arrive. Configuration problems and
        request/stream failures are not raised. Instead, one human-readable
        message is yielded and the stream ends: "Error: ...",
        "API HTTP Error (...)", "API Request Error: ..." or
        "An unexpected error occurred: ...".
    """
    provider_lower = provider.lower()
    api_key = _get_api_key(provider_lower, api_key_override)
    base_url = API_URLS.get(provider.upper())
    model_id = get_model_id_from_display_name(provider_lower, model_display_name)

    if not api_key:
        env_var_name = API_KEYS.get(provider.upper(), 'N/A')
        yield f"Error: API Key not found for {provider}. Please set it in the UI override or environment variable '{env_var_name}'."
        return
    if not base_url:
        yield f"Error: Unknown provider '{provider}' or missing API URL configuration."
        return
    if not model_id:
        yield f"Error: Unknown model '{model_display_name}' for provider '{provider}'. Please select a valid model."
        return

    logger.info("Calling %s/%s (ID: %s) stream...", provider, model_display_name, model_id)
    response = None
    try:
        if provider_lower in _OPENAI_COMPATIBLE_PROVIDERS:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            if provider_lower == "openrouter":
                headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/your_username/ai-space-builder"
                headers["X-Title"] = "AI Space Builder"
            payload = {"model": model_id, "messages": messages, "stream": True}
            response = requests.post(base_url, headers=headers, json=payload, stream=True, timeout=_STREAM_TIMEOUT_SECONDS)
            response.raise_for_status()
            yield from _iter_openai_compatible_text(response)

        elif provider_lower == "google":
            # alt=sse makes Gemini emit one `data: {json}` line per chunk. Without
            # it, the endpoint streams a pretty-printed JSON array that cannot be
            # parsed line by line. The key goes in a header rather than the query
            # string, so it never appears in error messages (which include the
            # request URL) or in logs.
            request_url = f"{base_url}{model_id}:streamGenerateContent?alt=sse"
            headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
            response = requests.post(request_url, headers=headers, json=_build_google_payload(messages), stream=True, timeout=_STREAM_TIMEOUT_SECONDS)
            response.raise_for_status()
            yield from _iter_google_text(response)

        elif provider_lower == "cohere":
            payload = _build_cohere_payload(model_id, messages)
            if payload is None:
                yield "Error: User message not found for Cohere API call."
                return
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            response = requests.post(base_url, headers=headers, json=payload, stream=True, timeout=_STREAM_TIMEOUT_SECONDS)
            response.raise_for_status()
            yield from _iter_cohere_text(response)

        elif provider_lower == "huggingface":
            yield "Error: Direct Hugging Face Inference API streaming for chat models is experimental and model-dependent. Consider using OpenRouter or TogetherAI for HF models with standardized streaming."
            return

        else:
            yield f"Error: Unsupported provider '{provider}' for streaming chat."
            return

    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code if e.response is not None else 'N/A'
        error_text = e.response.text if e.response is not None else 'No response text'
        logger.error("HTTP error during streaming for %s/%s: %s", provider, model_id, e)
        yield f"API HTTP Error ({status_code}): {error_text}\nDetails: {e}"
    except requests.exceptions.RequestException as e:
        logger.error("Request error during streaming for %s/%s: %s", provider, model_id, e)
        yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
    except Exception as e:
        logger.exception("Unexpected error during streaming for %s/%s:", provider, model_id)
        yield f"An unexpected error occurred: {e}"
    finally:
        # Release the pooled connection. This also runs if the consumer stops
        # iterating early (GeneratorExit), and only after the except handlers
        # above have read any error body.
        if response is not None:
            response.close()