from typing import Optional

import requests

from huggingface_hub import Discussion, HfApi, get_repo_discussions

from .utils import cached_file, http_user_agent, logging


logger = logging.get_logger(__name__)


def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]:
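    """Look for an existing open pull request titled `pr_title` on `model_id`.

    Only a PR whose branch sits directly on top of the current head of main is returned, so a stale PR
    (opened before newer commits landed on main) is ignored.
    """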
    main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id
    for discussion in get_repo_discussions(repo_id=model_id, token=token):
        if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request:
            commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token)

            # commits[0] is the PR's own commit, commits[1] is its parent: only accept the PR if it was
            # opened on top of the current head of main.
            if main_commit == commits[1].commit_id:
                return discussion
    return None


def spawn_conversion(token: str, private: bool, model_id: str):
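    """Trigger the `safetensors-convert` Space on the Hub and block until the conversion job reports completion."""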
    logger.info("Attempting to convert .bin model on the fly to safetensors.")

    safetensors_convert_space_url = "https://safetensors-convert.hf.space"
    sse_url = f"{safetensors_convert_space_url}/call/run"

    def start(_sse_connection):
        # Consume the server-sent events emitted by the Space until the conversion is finished.
        for line in _sse_connection.iter_lines():
            line = line.decode()
            if line.startswith("event:"):
                status = line[len("event:") :].strip()
                logger.debug(f"Safetensors conversion status: {status}")

                if status == "complete":
                    return
                elif status == "heartbeat":
                    logger.debug("Heartbeat")
                else:
                    logger.debug(f"Unknown status {status}")
            else:
                logger.debug(line)

    data = {"data": [model_id, private, token]}

    # The first request submits the job and returns an event id; the follow-up GET streams that job's events.
    result = requests.post(sse_url, stream=True, json=data).json()
    event_id = result["event_id"]

    with requests.get(f"{sse_url}/{event_id}", stream=True) as sse_connection:
        try:
            logger.debug("Spawning safetensors automatic conversion.")
            start(sse_connection)
        except Exception as e:
            logger.warning(f"Error during conversion: {repr(e)}")


def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
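    """Return the `refs/pr/<num>` revision of a safetensors conversion PR for `model_id`, spawning the
    conversion Space first if no suitable PR exists yet."""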
    private = api.model_info(model_id).private

    logger.info("Attempting to create safetensors variant")
    pr_title = "Adding `safetensors` variant of this model"
    token = kwargs.get("token")

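    # Look for an existing open conversion PR on this repo. For public repos, only a PR opened by the
    # conversion bot is reused; otherwise a fresh conversion is triggered.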
    pr = previous_pr(api, model_id, pr_title, token=token)

    if pr is None or (not private and pr.author != "SFConvertBot"):
        spawn_conversion(token, private, model_id)
        pr = previous_pr(api, model_id, pr_title, token=token)
    else:
        logger.info("Safetensors PR exists")

    # The conversion Space may fail to open a PR; signal that to the caller instead of raising on `pr.num`.
    if pr is None:
        return None

    sha = f"refs/pr/{pr.num}"

    return sha


def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
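    """Convert a `.bin`-only checkpoint to safetensors via a Hub PR and resolve the converted weights file.

    Returns a `(resolved_archive_file, sha, sharded)` tuple on success.
    """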
    try:
        api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()})
        sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)

        if sha is None:
            return None, None, None
        cached_file_kwargs["revision"] = sha
        del cached_file_kwargs["_commit_hash"]

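        # An extra existence check on the index file tells us whether the converted checkpoint is sharded.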
        sharded = api.file_exists(
            pretrained_model_name_or_path,
            "model.safetensors.index.json",
            revision=sha,
            token=cached_file_kwargs.get("token"),
        )
        filename = "model.safetensors.index.json" if sharded else "model.safetensors"

        resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
        return resolved_archive_file, sha, sharded
    except Exception as e:
        if not ignore_errors_during_conversion:
            raise e