akhaliq HF Staff committed on
Commit
f89ae07
·
verified ·
1 Parent(s): fe7f387

Update Gradio app with multiple files

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -14,7 +14,7 @@ model_id = "Qwen/Qwen3-VL-4B-Instruct"
14
  # Load model with optimizations for inference
15
  model = Qwen3VLForConditionalGeneration.from_pretrained(
16
  model_id,
17
- dtype="auto",
18
  device_map="auto"
19
  )
20
  processor = AutoProcessor.from_pretrained(model_id)
@@ -80,6 +80,9 @@ def process_chat_message(
80
  return_tensors="pt"
81
  )
82
 
 
 
 
83
  # Generate response
84
  with torch.no_grad():
85
  generated_ids = model.generate(
@@ -93,7 +96,7 @@ def process_chat_message(
93
  # Decode the generated response
94
  generated_ids_trimmed = [
95
  out_ids[len(in_ids):]
96
- for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
97
  ]
98
 
99
  response = processor.batch_decode(
 
14
  # Load model with optimizations for inference
15
  model = Qwen3VLForConditionalGeneration.from_pretrained(
16
  model_id,
17
+ torch_dtype=torch.bfloat16,
18
  device_map="auto"
19
  )
20
  processor = AutoProcessor.from_pretrained(model_id)
 
80
  return_tensors="pt"
81
  )
82
 
83
+ # Move inputs to the same device as the model
84
+ inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
85
+
86
  # Generate response
87
  with torch.no_grad():
88
  generated_ids = model.generate(
 
96
  # Decode the generated response
97
  generated_ids_trimmed = [
98
  out_ids[len(in_ids):]
99
+ for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
100
  ]
101
 
102
  response = processor.batch_decode(