jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#
import glob
import importlib
import inspect
import logging
import os
import re
import sys
from collections.abc import Iterable
from typing import Optional, Union
def _transform_changelog(path_in: str, path_out: str) -> None:
"""Adjust changelog titles so not to be duplicated.
Args:
path_in: input MD file
path_out: output also MD file
"""
with open(path_in) as fp:
chlog_lines = fp.readlines()
# enrich short subsub-titles to be unique
chlog_ver = ""
for i, ln in enumerate(chlog_lines):
if ln.startswith("## "):
chlog_ver = ln[2:].split("-")[0].strip()
elif ln.startswith("### "):
ln = ln.replace("###", f"### {chlog_ver} -")
chlog_lines[i] = ln
with open(path_out, "w") as fp:
fp.writelines(chlog_lines)
def _linkcode_resolve(
domain: str,
info: dict,
github_user: str,
github_repo: str,
main_branch: str = "master",
stable_branch: str = "release/stable",
) -> str:
def find_source() -> tuple[str, int, int]:
# try to find the file and line number, based on code from numpy:
# https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286
obj = sys.modules[info["module"]]
for part in info["fullname"].split("."):
obj = getattr(obj, part)
fname = str(inspect.getsourcefile(obj))
# https://github.com/rtfd/readthedocs.org/issues/5735
if any(s in fname for s in ("readthedocs", "rtfd", "checkouts")):
# /home/docs/checkouts/readthedocs.org/user_builds/pytorch_lightning/checkouts/
# devel/pytorch_lightning/utilities/cls_experiment.py#L26-L176
path_top = os.path.abspath(os.path.join("..", "..", ".."))
fname = str(os.path.relpath(fname, start=path_top))
else:
# Local build, imitate master
fname = f"{main_branch}/{os.path.relpath(fname, start=os.path.abspath('..'))}"
source, line_start = inspect.getsourcelines(obj)
return fname, line_start, line_start + len(source) - 1
if domain != "py" or not info["module"]:
return ""
try:
filename = "%s#L%d-L%d" % find_source() # noqa: UP031
except Exception:
filename = info["module"].replace(".", "/") + ".py"
# import subprocess
# tag = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE,
# universal_newlines=True).communicate()[0][:-1]
branch = filename.split("/")[0]
# do mapping from latest tags to master
branch = {"latest": main_branch, "stable": stable_branch}.get(branch, branch)
filename = "/".join([branch] + filename.split("/")[1:])
return f"https://github.com/{github_user}/{github_repo}/blob/{filename}"
def _load_pypi_versions(package_name: str) -> list[str]:
"""Load the versions of the package from PyPI.
>>> _load_pypi_versions("numpy") # doctest: +ELLIPSIS
['0.9.6', '0.9.8', '1.0', ...]
>>> _load_pypi_versions("scikit-learn") # doctest: +ELLIPSIS
['0.9', '0.10', '0.11', '0.12', ...]
"""
from distutils.version import LooseVersion
import requests
url = f"https://pypi.org/pypi/{package_name}/json"
data = requests.get(url, timeout=10).json()
versions = data["releases"].keys()
# filter all version which include only numbers and dots
versions = {k for k in versions if re.match(r"^\d+(\.\d+)*$", k)}
return sorted(versions, key=LooseVersion)
def _update_link_based_imported_package(link: str, pkg_ver: str, version_digits: Optional[int]) -> str:
"""Adjust the linked external docs to be local.
Args:
link: the source link to be replaced
pkg_ver: the target link to be replaced, if ``{package.version}`` is included it will be replaced accordingly
version_digits: for semantic versioning, how many digits to be considered
"""
pkg_att = pkg_ver.split(".")
try:
ver = _load_pypi_versions(pkg_att[0])[-1]
except Exception:
# load the package with all additional sub-modules
module = importlib.import_module(".".join(pkg_att[:-1]))
# load the attribute
ver = getattr(module, pkg_att[0])
# drop any additional context after `+`
ver = ver.split("+")[0]
# crop the version to the number of digits
ver = ".".join(ver.split(".")[:version_digits])
# replace the version
return link.replace(f"{{{pkg_ver}}}", ver)
def adjust_linked_external_docs(
source_link: str,
target_link: str,
browse_folder: Union[str, Iterable[str]],
file_extensions: Iterable[str] = (".rst", ".py"),
version_digits: int = 2,
) -> None:
r"""Adjust the linked external docs to be local.
Args:
source_link: the link to be replaced
target_link: the link to be replaced, if ``{package.version}`` is included it will be replaced accordingly
browse_folder: the location of the browsable folder
file_extensions: what kind of files shall be scanned
version_digits: for semantic versioning, how many digits to be considered
Examples:
>>> adjust_linked_external_docs(
... "https://numpy.org/doc/stable/",
... "https://numpy.org/doc/{numpy.__version__}/",
... "docs/source",
... )
"""
list_files = []
if isinstance(browse_folder, str):
browse_folder = [browse_folder]
for folder in browse_folder:
for ext in file_extensions:
list_files += glob.glob(os.path.join(folder, "**", f"*{ext}"), recursive=True)
if not list_files:
logging.warning(f'No files were listed in folder "{browse_folder}" and pattern "{file_extensions}"')
return
# find the expression for package version in {} brackets if any, use re to find it
pkg_ver_all = re.findall(r"{(.+)}", target_link)
for pkg_ver in pkg_ver_all:
target_link = _update_link_based_imported_package(target_link, pkg_ver, version_digits)
# replace the source link with target link
for fpath in set(list_files):
with open(fpath, encoding="UTF-8") as fopen:
lines = fopen.readlines()
found, skip = False, False
for i, ln in enumerate(lines):
# prevent the replacement its own function calls
if f"{adjust_linked_external_docs.__name__}(" in ln:
skip = True
if not skip and source_link in ln:
# replace the link if any found
lines[i] = ln.replace(source_link, target_link)
# record the found link for later write file
found = True
if skip and ")" in ln:
skip = False
if not found:
continue
logging.debug(f'links adjusting in {fpath}: "{source_link}" -> "{target_link}"')
with open(fpath, "w", encoding="UTF-8") as fw:
fw.writelines(lines)