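# Pixtral 12B demo Space: a multimodal Gradio chat interface served with
# mistral_inference on ZeroGPU (via the `spaces` package).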
import base64
import json
import os
from pathlib import Path

import requests
import gradio as gr
import spaces
from gradio.data_classes import FileData
from huggingface_hub import snapshot_download
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# Download the Pixtral weights, config, and tokenizer once at startup.
models_path = Path.home().joinpath('pixtral', 'Pixtral')
models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=models_path,
)

tokenizer = MistralTokenizer.from_file(f"{models_path}/tekken.json")
model = Transformer.from_folder(models_path)

def image_to_base64(image_path):
    """Read a local image file and return it as a base64 data URL."""
    with open(image_path, 'rb') as img:
        encoded_string = base64.b64encode(img.read()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded_string}"

def url_to_base64(image_url):
    """Fetch an image over HTTP and return it as a base64 data URL."""
    response = requests.get(image_url)
    if response.status_code == 200:
        base64_image = base64.b64encode(response.content).decode('utf-8')
        return f"data:image/jpeg;base64,{base64_image}"
    # If the fetch fails, fall back to an empty data URL.
    return "data:image/jpeg;base64,"
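# The structured path below expects message['text'] to be a JSON-encoded list
# of chat messages. A hypothetical payload illustrating the expected shape:
# [
#   {"role": "user", "content": [
#       {"type": "text", "text": "What is in this image?"},
#       {"type": "image", "url": "https://example.com/photo.jpg"}
#   ]},
#   {"role": "assistant", "content": [{"type": "text", "text": "A dog."}]}
# ]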
@spaces.GPU(duration=90)
def run_inference(message, history):
    try:
        # Structured path: parse message['text'] as a JSON message list and
        # convert each entry into the Mistral message types.
        messages = message['text']
        print("messages", messages)
        messages = json.loads(messages)
        final_msg = []
        for x in messages:
            if x['role'] == 'user':
                chunks = []
                for y in x['content']:
                    if y['type'] == 'image':
                        chunks.append(ImageURLChunk(image_url=url_to_base64(y['url'])))
                    else:
                        chunks.append(TextChunk(text=y['text']))
                final_msg.append(UserMessage(content=chunks))
            else:
                final_msg.append(AssistantMessage(content=x['content'][0]['text']))
        print('final msg', final_msg)
        completion_request = ChatCompletionRequest(messages=final_msg)
        encoded = tokenizer.encode_chat_completion(completion_request)
        images = encoded.images
        tokens = encoded.tokens
        out_tokens, _ = generate([tokens], model, images=[images], max_tokens=2048,
                                 temperature=0.45, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
        result = tokenizer.decode(out_tokens[0])
        return result
    # Fallback path: treat the input as a plain multimodal chat message.
    except Exception as e:
        print('falling back to default path:', e)
        messages = []
        images = []
        print('\n\nmessage', message)
        print('\n\nhistory', history)
        # Rebuild the conversation from history: image uploads arrive as tuple
        # turns and are attached to the next text turn.
        for couple in history:
            if isinstance(couple[0], tuple):
                images += couple[0]
            elif couple[0]:
                messages.append(UserMessage(content=[ImageURLChunk(image_url=image_to_base64(path)) for path in images] + [TextChunk(text=couple[0])]))
                messages.append(AssistantMessage(content=couple[1]))
                images = []
        # Current turn: attach any newly uploaded files plus the text prompt.
        messages.append(UserMessage(content=[ImageURLChunk(image_url=image_to_base64(file["path"])) for file in message["files"]] + [TextChunk(text=message["text"])]))
        print('\n\nfinal messages', messages)
        completion_request = ChatCompletionRequest(messages=messages)
        encoded = tokenizer.encode_chat_completion(completion_request)
        images = encoded.images
        tokens = encoded.tokens
        out_tokens, _ = generate([tokens], model, images=[images], max_tokens=512,
                                 temperature=0.45, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
        result = tokenizer.decode(out_tokens[0])
        return result
demo = gr.ChatInterface(fn=run_inference, title="Pixtral 12B", multimodal=True, description="A demo chat interface with Pixtral 12B, deployed using Mistral Inference.")
demo.queue().launch()