|
|
|
|
|
""" |
|
|
唐卡修复AI模型上传脚本 |
|
|
上传模型到Hugging Face |
|
|
Developed by Wangchuk Mind |
|
|
""" |
|
|
|
|
|
from huggingface_hub import HfApi, create_repo |
|
|
import os |
|
|
from pathlib import Path |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
REPO_ID = "Wangchuk1376/ThangkaModels" |
|
|
SCRIPT_DIR = Path(__file__).parent |
|
|
LOCAL_DIR = SCRIPT_DIR |
|
|
|
|
|
|
|
|
api = HfApi() |
|
|
|
|
|
|
|
|
def print_banner(): |
|
|
print("╔══════════════════════════════════════════════════════════════╗") |
|
|
print("║ ║") |
|
|
print("║ 🎨 唐卡修复AI模型 - Hugging Face上传工具 🎨 ║") |
|
|
print("║ ║") |
|
|
print(f"║ 上传到: {REPO_ID:38} ║") |
|
|
print("║ ║") |
|
|
print("╚══════════════════════════════════════════════════════════════╝") |
|
|
print() |
|
|
|
|
|
|
|
|
def create_repository(): |
|
|
"""创建或验证仓库""" |
|
|
print("🔧 步骤1: 创建/验证仓库...") |
|
|
try: |
|
|
create_repo( |
|
|
repo_id=REPO_ID, |
|
|
repo_type="model", |
|
|
exist_ok=True, |
|
|
private=False |
|
|
) |
|
|
print(f"✅ 仓库 {REPO_ID} 已创建/验证") |
|
|
print(f"🌐 仓库地址: https://huggingface.co/{REPO_ID}") |
|
|
return True |
|
|
except Exception as e: |
|
|
print(f"❌ 创建仓库失败: {e}") |
|
|
print() |
|
|
print("💡 请手动创建仓库:") |
|
|
print(f" 1. 访问 https://huggingface.co/new") |
|
|
print(f" 2. Owner: Wangchuk1376") |
|
|
print(f" 3. Model name: ThangkaModels") |
|
|
print(f" 4. License: MIT") |
|
|
print(f" 5. Visibility: Public") |
|
|
return False |
|
|
|
|
|
|
|
|
def upload_gitattributes(): |
|
|
"""上传Git LFS配置""" |
|
|
print("\n📝 步骤2: 上传.gitattributes...") |
|
|
|
|
|
gitattributes_content = """# 使用Git LFS跟踪大文件 |
|
|
*.safetensors filter=lfs diff=lfs merge=lfs -text |
|
|
*.pdparams filter=lfs diff=lfs merge=lfs -text |
|
|
*.bin filter=lfs diff=lfs merge=lfs -text |
|
|
*.ckpt filter=lfs diff=lfs merge=lfs -text |
|
|
*.pth filter=lfs diff=lfs merge=lfs -text |
|
|
*.h5 filter=lfs diff=lfs merge=lfs -text |
|
|
*.pb filter=lfs diff=lfs merge=lfs -text |
|
|
""" |
|
|
|
|
|
try: |
|
|
api.upload_file( |
|
|
path_or_fileobj=gitattributes_content.encode(), |
|
|
path_in_repo=".gitattributes", |
|
|
repo_id=REPO_ID, |
|
|
repo_type="model" |
|
|
) |
|
|
print("✅ .gitattributes 上传成功") |
|
|
return True |
|
|
except Exception as e: |
|
|
print(f"⚠️ .gitattributes上传失败: {e}") |
|
|
return False |
|
|
|
|
|
|
|
|
def upload_readme(): |
|
|
"""上传README文件""" |
|
|
print("\n📝 步骤3: 上传README...") |
|
|
|
|
|
readme_path = LOCAL_DIR / "README.md" |
|
|
if not readme_path.exists(): |
|
|
print("⚠️ README.md 不存在,跳过") |
|
|
return True |
|
|
|
|
|
try: |
|
|
api.upload_file( |
|
|
path_or_fileobj=str(readme_path), |
|
|
path_in_repo="README.md", |
|
|
repo_id=REPO_ID, |
|
|
repo_type="model" |
|
|
) |
|
|
print("✅ README.md 上传成功") |
|
|
return True |
|
|
except Exception as e: |
|
|
print(f"❌ README上传失败: {e}") |
|
|
return False |
|
|
|
|
|
|
|
|
def upload_single_file(file_path, repo_path): |
|
|
"""上传单个文件""" |
|
|
try: |
|
|
api.upload_file( |
|
|
path_or_fileobj=str(file_path), |
|
|
path_in_repo=repo_path, |
|
|
repo_id=REPO_ID, |
|
|
repo_type="model" |
|
|
) |
|
|
return True |
|
|
except Exception as e: |
|
|
print(f" ❌ {repo_path}: {e}") |
|
|
return False |
|
|
|
|
|
|
|
|
def upload_models_directory(): |
|
|
"""分批上传models目录""" |
|
|
print("\n📤 步骤4: 上传models目录...") |
|
|
print(" ⏳ 这可能需要较长时间,请耐心等待...") |
|
|
print() |
|
|
|
|
|
models_dir = LOCAL_DIR / "models" |
|
|
if not models_dir.exists(): |
|
|
print("⚠️ models目录不存在,跳过") |
|
|
return True |
|
|
|
|
|
|
|
|
all_files = [] |
|
|
for root, dirs, files in os.walk(models_dir): |
|
|
|
|
|
dirs[:] = [d for d in dirs if not d.startswith('.')] |
|
|
|
|
|
for file in files: |
|
|
if file.startswith('.'): |
|
|
continue |
|
|
|
|
|
file_path = Path(root) / file |
|
|
relative_path = file_path.relative_to(LOCAL_DIR) |
|
|
all_files.append((file_path, str(relative_path))) |
|
|
|
|
|
print(f" 📊 找到 {len(all_files)} 个文件") |
|
|
print() |
|
|
|
|
|
|
|
|
success_count = 0 |
|
|
fail_count = 0 |
|
|
|
|
|
for file_path, repo_path in tqdm(all_files, desc="上传进度"): |
|
|
if upload_single_file(file_path, repo_path): |
|
|
success_count += 1 |
|
|
else: |
|
|
fail_count += 1 |
|
|
|
|
|
print() |
|
|
print(f"✅ 成功上传: {success_count} 个文件") |
|
|
if fail_count > 0: |
|
|
print(f"⚠️ 失败: {fail_count} 个文件") |
|
|
|
|
|
return fail_count == 0 |
|
|
|
|
|
|
|
|
def upload_entire_folder(): |
|
|
"""上传整个文件夹 (一次性上传)""" |
|
|
print("\n📤 步骤4: 上传整个目录...") |
|
|
print(" ⏳ 这可能需要较长时间,请耐心等待...") |
|
|
|
|
|
try: |
|
|
api.upload_folder( |
|
|
folder_path=str(LOCAL_DIR), |
|
|
repo_id=REPO_ID, |
|
|
repo_type="model", |
|
|
ignore_patterns=[ |
|
|
".DS_Store", |
|
|
"*.pyc", |
|
|
"__pycache__", |
|
|
"*.sh", |
|
|
"*.py", |
|
|
"fix_upload_issues.md", |
|
|
".git", |
|
|
".gitignore" |
|
|
], |
|
|
multi_commits=True, |
|
|
multi_commits_verbose=True |
|
|
) |
|
|
print("✅ 所有文件上传成功!") |
|
|
return True |
|
|
except Exception as e: |
|
|
print(f"❌ 上传失败: {e}") |
|
|
print() |
|
|
print("💡 建议:") |
|
|
print(" 1. 检查网络连接") |
|
|
print(" 2. 尝试分批上传") |
|
|
print(" 3. 使用Git LFS方式上传") |
|
|
return False |
|
|
|
|
|
|
|
|
def verify_upload(): |
|
|
"""验证上传结果""" |
|
|
print("\n🔍 步骤5: 验证上传...") |
|
|
|
|
|
try: |
|
|
|
|
|
info = api.repo_info(repo_id=REPO_ID, repo_type="model") |
|
|
print(f"✅ 仓库验证成功") |
|
|
print(f" 最后更新: {info.last_modified}") |
|
|
return True |
|
|
except Exception as e: |
|
|
print(f"⚠️ 无法验证: {e}") |
|
|
return False |
|
|
|
|
|
|
|
|
def show_completion(): |
|
|
"""显示完成信息""" |
|
|
print() |
|
|
print("╔══════════════════════════════════════════════════════════════╗") |
|
|
print("║ ║") |
|
|
print("║ 🎉 上传完成! 🎉 ║") |
|
|
print("║ ║") |
|
|
print("╚══════════════════════════════════════════════════════════════╝") |
|
|
print() |
|
|
print(f"📦 模型仓库: https://huggingface.co/{REPO_ID}") |
|
|
print() |
|
|
print("📚 使用方法:") |
|
|
print() |
|
|
print(" # 使用CLI下载") |
|
|
print(f" huggingface-cli download {REPO_ID} --local-dir ./models") |
|
|
print() |
|
|
print(" # 使用Python下载") |
|
|
print(" from huggingface_hub import snapshot_download") |
|
|
print(f' snapshot_download(repo_id="{REPO_ID}", local_dir="./models")') |
|
|
print() |
|
|
print(" # 在代码中使用") |
|
|
print(" from diffusion_paddle import load_model") |
|
|
print(f' pipe = load_model("{REPO_ID}/sd2.1_base_paddle")') |
|
|
print() |
|
|
print("🌟 别忘了给项目点星!") |
|
|
print(" GitHub: https://github.com/WangchukMind/thangka-restoration-ai") |
|
|
print() |
|
|
|
|
|
|
|
|
def main(): |
|
|
"""主函数""" |
|
|
print_banner() |
|
|
|
|
|
|
|
|
try: |
|
|
user = api.whoami() |
|
|
print(f"👤 当前用户: {user['name']}") |
|
|
print() |
|
|
except Exception as e: |
|
|
print("❌ 未登录Hugging Face") |
|
|
print() |
|
|
print("请先登录:") |
|
|
print(" hf auth login") |
|
|
print() |
|
|
return |
|
|
|
|
|
|
|
|
if not create_repository(): |
|
|
print() |
|
|
print("⚠️ 请先手动创建仓库,然后重新运行此脚本") |
|
|
return |
|
|
|
|
|
|
|
|
upload_gitattributes() |
|
|
|
|
|
|
|
|
upload_readme() |
|
|
|
|
|
|
|
|
|
|
|
print() |
|
|
print("请选择上传方式:") |
|
|
print(" 1. 分批上传 (推荐,更稳定)") |
|
|
print(" 2. 一次性上传 (更快,但可能失败)") |
|
|
|
|
|
try: |
|
|
choice = input("\n请输入选择 (1/2) [1]: ").strip() or "1" |
|
|
|
|
|
if choice == "1": |
|
|
upload_models_directory() |
|
|
else: |
|
|
upload_entire_folder() |
|
|
except KeyboardInterrupt: |
|
|
print("\n\n⚠️ 上传已取消") |
|
|
return |
|
|
|
|
|
|
|
|
verify_upload() |
|
|
|
|
|
|
|
|
show_completion() |
|
|
|
|
|
if __name__ == "__main__": |
|
|
try: |
|
|
main() |
|
|
except KeyboardInterrupt: |
|
|
print("\n\n⚠️ 程序已中断") |
|
|
except Exception as e: |
|
|
print(f"\n❌ 发生错误: {e}") |
|
|
print() |
|
|
print("请查看 fix_upload_issues.md 了解更多故障排查方法") |
|
|
|
|
|
|