dicksonhk committed
Commit 6c22125 · 1 Parent(s): 0aca28b

Fix edge case error when converting vlm models: "Received parameters not in model"

Files changed (1):
  1. app.py +91 -15
app.py CHANGED
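The error in the title comes from mlx's strict weight-loading check rejecting checkpoint tensors that have no matching parameter in the instantiated model. Instead of rewriting the offending weights (an earlier attempt is left commented out in the diff), the fix retries a failed `mlx_vlm.convert` with the `strict` default temporarily flipped to `False` on `nn.Module.load_weights`, `update`, and `update_modules`, so unmatched tensors are skipped.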
```diff
@@ -1,8 +1,14 @@
+from pathlib import Path
+import traceback
+
 import os
 import tempfile
 import importlib.util
 from enum import Enum
 
+from contextlib import contextmanager, AbstractContextManager
+from functools import wraps
+
 os.environ["HF_HUB_CACHE"] = "cache"
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 import gradio as gr
@@ -18,9 +24,12 @@ from apscheduler.schedulers.background import BackgroundScheduler
 
 from textwrap import dedent
 from typing import (
+    Any,
     Callable,
     Dict,
     Optional,
+    Tuple,
+    Type,
     Union,
     NamedTuple,
 )
@@ -172,6 +181,47 @@ def upload_to_hub(path, upload_repo, hf_path, oauth_token, runtime: Runtime):
 
     print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
 
+
+@contextmanager
+def patch_strict_default_methods_ctx() -> AbstractContextManager[Callable[[Any, str], None]]:
+    """
+    Context manager to temporarily set the default value of the 'strict' arg to `False`
+    for specified class methods.
+    Does not affect explicit `strict=True`.
+
+    (e.g. `def update(self, parameters: dict, strict: bool = True)`
+    becomes `def update(self, parameters: dict, strict: bool = False)`)
+
+    Typical usage:
+
+        with patch_strict_default_methods_ctx() as patch:
+            patch(Foo, "bar")
+            patch(Foo, "baz")
+            patch(Bar, "foo")
+            # Patched methods active here
+        # Originals restored here
+    """
+
+    originals: Dict[Tuple[Type[Any], str], Callable] = {}
+
+    def patch(cls: Any, method_name: str):
+        method = getattr(cls, method_name)
+        originals[(cls, method_name)] = method
+
+        @wraps(method)
+        def wrapper(self, *args, strict=False, **kwargs):
+            return method(self, *args, strict=strict, **kwargs)
+
+        setattr(cls, method_name, wrapper)
+
+    try:
+        yield patch
+    finally:
+        # Restore all patched methods
+        for (cls, method_name), original in originals.items():
+            setattr(cls, method_name, original)
+        originals.clear()
+
 def convert(
     hf_path: str,
     mlx_path: str = "mlx_model",
```
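As a sanity check on the mechanics, here is a minimal, self-contained sketch of the context manager in action. `Loader` is a hypothetical stand-in class (not part of the commit); only `patch_strict_default_methods_ctx` above is assumed.

```python
# Hypothetical toy class standing in for nn.Module, used only to show the mechanics.
class Loader:
    def load(self, data, strict=True):
        return f"strict={strict}"

loader = Loader()
with patch_strict_default_methods_ctx() as patch:
    patch(Loader, "load")
    print(loader.load({}))                # strict=False: the default was flipped
    print(loader.load({}, strict=True))   # an explicit strict=True still wins
print(loader.load({}))                    # strict=True: original method restored
```

One caveat: the wrapper takes `strict` keyword-only, so a caller passing it positionally would hit a `TypeError` (`strict` would land in `*args` and then be supplied twice). That is harmless here as long as the patched mlx methods receive `strict` by keyword or not at all.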
```diff
@@ -188,6 +238,8 @@ def convert(
     skip_vision: bool = False, # mlx-vlm
     trust_remote_code: bool = True, # mlx-vlm
 ) -> Runtime :
+    model_path = get_model_path(hf_path, revision=revision)
+
     def mlx_lm_convert():
         mlx_lm.convert(
             hf_path=hf_path,
@@ -203,21 +255,44 @@ def convert(
         )
 
     def mlx_vlm_convert():
-        mlx_vlm.convert(
-            hf_path=hf_path,
-            mlx_path=mlx_path,
-            quantize=quantize,
-            q_group_size=q_group_size,
-            q_bits=q_bits,
-            dtype=dtype,
-            upload_repo=upload_repo,
-            revision=revision,
-            dequantize=dequantize,
-            skip_vision=skip_vision,
-            trust_remote_code=trust_remote_code,
-        )
-
-    model_path = get_model_path(hf_path, revision=revision)
+        # try:
+        #     new_model_path = remove_extra_parameters_from_weights(model_path=model_path)
+        #     print(f"{new_model_path} exists: {Path(new_model_path).exists()}")
+        # except Exception as e:
+        #     new_model_path = model_path
+        #     print(f"Unexpected error while trying to fix model weights: {e}")
+        #     traceback.print_exc()
+        #     raise e
+
+        def _mlx_vlm_convert():
+            mlx_vlm.convert(
+                # hf_path=new_model_path,
+                hf_path=hf_path,
+                mlx_path=mlx_path,
+                quantize=quantize,
+                q_group_size=q_group_size,
+                q_bits=q_bits,
+                dtype=dtype,
+                upload_repo=upload_repo,
+                revision=revision,
+                dequantize=dequantize,
+                skip_vision=skip_vision,
+                trust_remote_code=trust_remote_code,
+            )
+
+        try:
+            _mlx_vlm_convert()
+        except ValueError as e:
+            print(e)
+            print("Error converting, trying again with strict=False")
+            with patch_strict_default_methods_ctx() as patch:
+                import mlx.nn as nn
+                patch(nn.Module, "load_weights")
+                patch(nn.Module, "update")
+                patch(nn.Module, "update_modules")
+                # patched strict=False by default, try again
+                _mlx_vlm_convert()
+
     config = load_config(model_path)
     model_type = config["model_type"]
     model_type = MODEL_REMAPPING.get(model_type, model_type)
```
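To see the failure mode the retry works around, here is a short sketch assuming mlx's `nn.Module.load_weights(weights, strict=...)` API: with `strict=True` the provided weights must exactly match the model's parameters, so an extra tensor raises a `ValueError`, while `strict=False` loads only what matches.

```python
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(2, 2)
weights = [
    ("weight", mx.zeros((2, 2))),
    ("bias", mx.zeros((2,))),
    ("lm_head.weight", mx.zeros((2, 2))),  # extra tensor with no matching parameter
]

try:
    model.load_weights(weights, strict=True)   # strict: mismatch raises ValueError
except ValueError as e:
    print(f"strict load failed: {e}")

model.load_weights(weights, strict=False)      # non-strict: extra tensor is ignored
```

Patching the default, rather than passing `strict=False` directly, is presumably necessary because the failing calls happen deep inside `mlx_vlm.convert`, where the caller has no way to forward the flag.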
```diff
@@ -273,6 +348,7 @@ def process_model(model_id, q_method, oauth_token: gr.OAuthToken | None):
             "llama.png",
         )
     except Exception as e:
+        traceback.print_exc()
         return (f"Error: {e}", "error.png")
     finally:
         clear_hf_cache_space()
```
 