prithivMLmods committed
Commit 5cdeb4d · verified · 1 Parent(s): 914bd4d

Update app.py

Files changed (1): app.py +5 -1
app.py CHANGED
@@ -20,6 +20,7 @@ from transformers import (
     AutoModelForVision2Seq,
     AutoProcessor,
     TextIteratorStreamer,
+    EncoderDecoderCache,  # Added to handle the new caching mechanism
 )
 from transformers.image_utils import load_image
 
@@ -137,6 +138,8 @@ def model_chat(prompt, image):
         add_special_tokens=False,
         return_tensors="pt"
     ).to(device)
+
+    # Explicitly set past_key_values to None to align with the new caching mechanism and avoid the deprecated tuple warning
     outputs = model.generate(
         pixel_values=pixel_values,
         decoder_input_ids=prompt_inputs.input_ids,
@@ -150,7 +153,8 @@ def model_chat(prompt, image):
         return_dict_in_generate=True,
         do_sample=False,
         num_beams=1,
-        repetition_penalty=1.1
+        repetition_penalty=1.1,
+        past_key_values=None,  # Added to prevent deprecated tuple handling
     )
     sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
     cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
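
For context, here is a minimal sketch of what the newly imported EncoderDecoderCache enables as an alternative to past_key_values=None: constructing an explicit cache object and passing it to generate(). DynamicCache and EncoderDecoderCache are real transformers classes, but this snippet is not part of the commit; `model`, `pixel_values`, and `prompt_inputs` are assumed to be defined as in the surrounding app.py, and the generate() kwargs simply mirror the ones in the diff.

# A minimal sketch, not part of this commit: pass an explicit
# EncoderDecoderCache instead of past_key_values=None. Assumes `model`,
# `pixel_values`, and `prompt_inputs` exist as elsewhere in app.py.
from transformers import DynamicCache, EncoderDecoderCache

# One Cache for decoder self-attention and one for encoder-decoder
# cross-attention; together they replace the deprecated tuple format.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

outputs = model.generate(
    pixel_values=pixel_values,
    decoder_input_ids=prompt_inputs.input_ids,
    return_dict_in_generate=True,
    do_sample=False,
    num_beams=1,
    repetition_penalty=1.1,
    past_key_values=cache,  # explicit Cache object; skips the legacy-tuple path
)

Either form avoids the deprecation warning: passing None, as the commit does, lets generate() construct the cache internally, while an explicit EncoderDecoderCache makes the cache type visible at the call site.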