Final_Assignment_Template / mistral_hf_wrapper.py
FD900's picture
Update mistral_hf_wrapper.py
b2bf871 verified
import os
import requests
class MistralInference:
    """Thin client for a Hugging Face inference endpoint serving a Mistral model.

    Configuration falls back to environment variables when not passed
    explicitly: ``HF_MISTRAL_ENDPOINT`` for the endpoint URL and ``HF_TOKEN``
    for the bearer token. :meth:`run` does not raise on request/response
    failures; it returns a string prefixed with ``"Error querying Mistral:"``.
    """

    def __init__(self, api_url=None, api_token=None, max_new_tokens: int = 512, timeout: float = 30):
        """Store endpoint configuration.

        Args:
            api_url: Endpoint URL; ``HF_MISTRAL_ENDPOINT`` is used when falsy.
            api_token: Bearer token; ``HF_TOKEN`` is used when falsy.
            max_new_tokens: Value sent as ``parameters.max_new_tokens``.
            timeout: Request timeout in seconds passed to ``requests.post``.
        """
        self.api_url = api_url or os.getenv("HF_MISTRAL_ENDPOINT")
        self.api_token = api_token or os.getenv("HF_TOKEN")
        self.max_new_tokens = max_new_tokens
        self.timeout = timeout

    def _build_headers(self) -> dict:
        """Return the JSON request headers, adding Authorization only when a token is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_token:
            # Without this guard a missing token is sent as the literal "Bearer None".
            headers["Authorization"] = f"Bearer {self.api_token}"
        return headers

    @staticmethod
    def _extract_text(output) -> str:
        """Pull the generated text out of a decoded JSON response.

        Recognised shapes: ``[{"generated_text": ...}, ...]`` (first element
        used), ``{"generated_text": ...}``, then ``{"text": ...}``. Anything
        else -- an empty list, an error payload, a scalar -- is returned as
        ``str(output)``.
        """
        if isinstance(output, list):
            first = output[0] if output else None
            if isinstance(first, dict) and "generated_text" in first:
                return first["generated_text"]
        elif isinstance(output, dict):
            # Check different possible keys depending on model.
            for key in ("generated_text", "text"):
                if key in output:
                    return output[key]
        return str(output)

    def run(self, prompt: str) -> str:
        """Send *prompt* to the endpoint and return the generated text.

        Returns:
            The generated text on success. Otherwise a string starting with
            ``"Error querying Mistral:"``, returned when no endpoint is
            configured, the request fails, the HTTP status is an error, or
            the body cannot be decoded as JSON.
        """
        if not self.api_url:
            # Fail fast with an actionable message instead of letting
            # requests complain about the URL "None".
            return (
                "Error querying Mistral: no endpoint configured "
                "(pass api_url or set HF_MISTRAL_ENDPOINT)"
            )
        payload = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": self.max_new_tokens},
        }
        # NOTE(review): HF text-generation endpoints may echo the prompt in
        # "generated_text" unless "return_full_text" is False -- confirm
        # whether callers expect that.
        try:
            response = requests.post(
                self.api_url,
                headers=self._build_headers(),
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            output = response.json()
        except Exception as e:  # deliberate boundary: callers always get a string back
            return f"Error querying Mistral: {str(e)}"
        return self._extract_text(output)