z-coder committed on
Commit 242ba54 · verified · 1 Parent(s): e1ef4f1

Update app.py

Files changed (1):
  1. app.py +37 -10
app.py CHANGED
@@ -1,14 +1,41 @@
-import gradio as gr
-from transformers import AutoProcessor, AutoModelForCausalLM
-from PIL import Image
-import torch
-
-model = AutoModelForCausalLM.from_pretrained("llava-hf/llava-1.5-7b-hf", torch_dtype=torch.float16).to("cuda")
-processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
-
-def chat(image, prompt):
-    inputs = processor(prompt, images=image, return_tensors="pt").to("cuda")
-    output = model.generate(**inputs, max_new_tokens=50)
-    return processor.tokenizer.decode(output[0], skip_special_tokens=True)
-
-gr.Interface(fn=chat, inputs=["image", "text"], outputs="text").launch()
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor
+from PIL import Image
+import gradio as gr
+
+# Load the MAGMA model and processor
+model_id = "microsoft/Magma-8B"
+processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, trust_remote_code=True)
+model.to("cuda" if torch.cuda.is_available() else "cpu")
+
+def magma_inference(image: Image.Image, prompt: str) -> str:
+    # Prepare conversation
+    convs = [
+        {"role": "system", "content": "You are an agent that can see, talk, and act."},
+        {"role": "user", "content": prompt}
+    ]
+    # Generate prompt
+    text_prompt = processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
+    # Process inputs
+    inputs = processor(images=[image], texts=text_prompt, return_tensors="pt").to(model.device)
+    # Generate output
+    with torch.inference_mode():
+        generate_ids = model.generate(**inputs, max_new_tokens=50)
+    generate_ids = generate_ids[:, inputs["input_ids"].shape[-1]:]
+    response = processor.decode(generate_ids[0], skip_special_tokens=True).strip()
+    return response
+
+# Gradio interface
+interface = gr.Interface(
+    fn=magma_inference,
+    inputs=[
+        gr.Image(type="pil", label="Input Image"),
+        gr.Textbox(label="Prompt")
+    ],
+    outputs=gr.Textbox(label="MAGMA Output"),
+    title="MAGMA Image + Text to Text API",
+    description="Upload an image and enter a prompt. Returns MAGMA's textual response."
+)
+
+app = gr.mount_gradio_app(app=interface, path="/")
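One note on the last added line: gr.mount_gradio_app expects a FastAPI application as its first argument and the Gradio Blocks/Interface as its second, so gr.mount_gradio_app(app=interface, path="/") will fail with a missing-argument error, and nothing in the new app.py ever launches the demo. A minimal sketch of how the interface could be served, reusing the interface object from the diff above; the FastAPI host app and the uvicorn call are assumptions, not part of this commit:

# Sketch only: serve the Gradio interface, assuming fastapi and uvicorn are installed.
from fastapi import FastAPI
import uvicorn
import gradio as gr

fastapi_app = FastAPI()  # hypothetical host application
# gr.mount_gradio_app takes (FastAPI app, Blocks/Interface, path)
app = gr.mount_gradio_app(fastapi_app, interface, path="/")  # `interface` comes from app.py above

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)

On a plain Gradio Space, simply calling interface.launch() at the end of app.py would also work.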
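Once the Space is up, the single endpoint exposed by gr.Interface can also be called remotely with gradio_client. A small usage sketch, assuming a hypothetical Space id and a local test image (both placeholders, not from the commit):

# Sketch only: query the deployed demo from Python with gradio_client.
from gradio_client import Client, handle_file

client = Client("z-coder/magma-demo")   # hypothetical Space id
result = client.predict(
    handle_file("example.jpg"),         # local image file to upload
    "What is in this image?",           # prompt text
    api_name="/predict",                # default endpoint name for gr.Interface
)
print(result)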