add sanity checks
Browse files
custom_generate/generate.py
CHANGED
@@ -203,7 +203,8 @@ def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
|
|
203 |
or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
|
204 |
)
|
205 |
)
|
206 |
-
|
|
|
207 |
raise ValueError(
|
208 |
f"`{arg}` is set, but it's not supported in this custom generate function. List of "
|
209 |
f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
|
|
|
203 |
or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
|
204 |
)
|
205 |
)
|
206 |
+
kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
|
207 |
+
if kwargs_has_arg or has_custom_gen_config_arg:
|
208 |
raise ValueError(
|
209 |
f"`{arg}` is set, but it's not supported in this custom generate function. List of "
|
210 |
f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
|