Fix edge case error when converting vlm models: "Received parameters not in model"

app.py
CHANGED
@@ -1,8 +1,14 @@
+from pathlib import Path
+import traceback
+
 import os
 import tempfile
 import importlib.util
 from enum import Enum
 
+from contextlib import contextmanager, AbstractContextManager
+from functools import wraps
+
 os.environ["HF_HUB_CACHE"] = "cache"
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 import gradio as gr
@@ -18,9 +24,12 @@ from apscheduler.schedulers.background import BackgroundScheduler
 
 from textwrap import dedent
 from typing import (
+    Any,
     Callable,
     Dict,
     Optional,
+    Tuple,
+    Type,
     Union,
     NamedTuple,
 )
@@ -172,6 +181,47 @@ def upload_to_hub(path, upload_repo, hf_path, oauth_token, runtime: Runtime):
 
     print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
 
+
+@contextmanager
+def patch_strict_default_methods_ctx() -> AbstractContextManager[Callable[[Any, str], None]]:
+    """
+    Context manager to temporarily set the default value of the 'strict' arg to `False`
+    for specified class methods.
+    Does not affect an explicit `strict=True`.
+
+    (e.g. `def update(self, parameters: dict, strict: bool = True)`
+    becomes `def update(self, parameters: dict, strict: bool = False)`)
+
+    Typical usage:
+
+        with patch_strict_default_methods_ctx() as patch:
+            patch(Foo, "bar")
+            patch(Foo, "baz")
+            patch(Bar, "foo")
+            # Patched methods active here
+        # Originals restored here
+    """
+
+    originals: Dict[Tuple[Type[Any], str], Callable] = {}
+
+    def patch(cls: Any, method_name: str):
+        method = getattr(cls, method_name)
+        originals[(cls, method_name)] = method
+
+        @wraps(method)
+        def wrapper(self, *args, strict=False, **kwargs):
+            return method(self, *args, strict=strict, **kwargs)
+
+        setattr(cls, method_name, wrapper)
+
+    try:
+        yield patch
+    finally:
+        # Restore all patched methods
+        for (cls, method_name), original in originals.items():
+            setattr(cls, method_name, original)
+        originals.clear()
+
 def convert(
     hf_path: str,
     mlx_path: str = "mlx_model",
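
A self-contained sketch of how the helper above behaves (the `Demo` class and its values are hypothetical, not part of app.py): only the default flips to `strict=False`; an explicit `strict=True` still raises.

    class Demo:
        def check(self, value, strict: bool = True):
            # Stand-in for an mlx-style strict check that rejects
            # unexpected input unless strict is disabled.
            if strict and value != "expected":
                raise ValueError(f"unexpected value: {value}")
            return value

    with patch_strict_default_methods_ctx() as patch:
        patch(Demo, "check")
        Demo().check("surprise")                 # ok: default is now strict=False
        # Demo().check("surprise", strict=True)  # would still raise

    Demo().check("expected")  # original method (strict=True default) restored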
@@ -188,6 +238,8 @@ def convert(
     skip_vision: bool = False,  # mlx-vlm
     trust_remote_code: bool = True,  # mlx-vlm
 ) -> Runtime:
+    model_path = get_model_path(hf_path, revision=revision)
+
     def mlx_lm_convert():
         mlx_lm.convert(
             hf_path=hf_path,
@@ -203,21 +255,44 @@ def convert(
         )
 
     def mlx_vlm_convert():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # try:
+        #     new_model_path = remove_extra_parameters_from_weights(model_path=model_path)
+        #     print(f"{new_model_path} exists: {Path(new_model_path).exists()}")
+        # except Exception as e:
+        #     new_model_path = model_path
+        #     print(f"Unexpected error while trying to fix model weights: {e}")
+        #     traceback.print_exc()
+        #     raise e
+
+        def _mlx_vlm_convert():
+            mlx_vlm.convert(
+                # hf_path=new_model_path,
+                hf_path=hf_path,
+                mlx_path=mlx_path,
+                quantize=quantize,
+                q_group_size=q_group_size,
+                q_bits=q_bits,
+                dtype=dtype,
+                upload_repo=upload_repo,
+                revision=revision,
+                dequantize=dequantize,
+                skip_vision=skip_vision,
+                trust_remote_code=trust_remote_code,
+            )
+
+        try:
+            _mlx_vlm_convert()
+        except ValueError as e:
+            print(e)
+            print("Error converting, retrying with strict=False")
+            with patch_strict_default_methods_ctx() as patch:
+                import mlx.nn as nn
+                patch(nn.Module, "load_weights")
+                patch(nn.Module, "update")
+                patch(nn.Module, "update_modules")
+                # patched strict=False by default, try again
+                _mlx_vlm_convert()
+
     config = load_config(model_path)
     model_type = config["model_type"]
     model_type = MODEL_REMAPPING.get(model_type, model_type)
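
For context on the retry above: a strict load raises ValueError("Received parameters not in model: ...") when the checkpoint carries weight keys the module tree does not define, while strict=False skips that check so the extra parameters are ignored. A minimal sketch of that failure mode, assuming it boils down to a key-set comparison; the load function and names here are hypothetical, not mlx code.

    def load(weights: dict, model_keys: set, strict: bool = True) -> dict:
        # Weight keys the model does not define abort the load
        # unless strict is disabled.
        extra = set(weights) - model_keys
        if strict and extra:
            raise ValueError(f"Received parameters not in model: {sorted(extra)}")
        return {k: v for k, v in weights.items() if k in model_keys}

    weights = {"proj.weight": 1.0, "proj.bias": 0.0, "head.unused": 0.0}
    keys = {"proj.weight", "proj.bias"}
    # load(weights, keys)  # raises: 'head.unused' not in model
    assert load(weights, keys, strict=False) == {"proj.weight": 1.0, "proj.bias": 0.0}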
@@ -273,6 +348,7 @@ def process_model(model_id, q_method, oauth_token: gr.OAuthToken | None):
             "llama.png",
         )
     except Exception as e:
+        traceback.print_exc()
         return (f"Error: {e}", "error.png")
     finally:
         clear_hf_cache_space()