#!/usr/bin/env python3 """ 唐卡修复AI模型上传脚本 上传模型到Hugging Face Developed by Wangchuk Mind """ from huggingface_hub import HfApi, create_repo import os from pathlib import Path from tqdm import tqdm # ===== 配置 ===== REPO_ID = "Wangchuk1376/ThangkaModels" SCRIPT_DIR = Path(__file__).parent LOCAL_DIR = SCRIPT_DIR # 初始化API api = HfApi() # ===== 打印横幅 ===== def print_banner(): print("╔══════════════════════════════════════════════════════════════╗") print("║ ║") print("║ 🎨 唐卡修复AI模型 - Hugging Face上传工具 🎨 ║") print("║ ║") print(f"║ 上传到: {REPO_ID:38} ║") print("║ ║") print("╚══════════════════════════════════════════════════════════════╝") print() # ===== 创建仓库 ===== def create_repository(): """创建或验证仓库""" print("🔧 步骤1: 创建/验证仓库...") try: create_repo( repo_id=REPO_ID, repo_type="model", exist_ok=True, private=False ) print(f"✅ 仓库 {REPO_ID} 已创建/验证") print(f"🌐 仓库地址: https://huggingface.co/{REPO_ID}") return True except Exception as e: print(f"❌ 创建仓库失败: {e}") print() print("💡 请手动创建仓库:") print(f" 1. 访问 https://huggingface.co/new") print(f" 2. Owner: Wangchuk1376") print(f" 3. Model name: ThangkaModels") print(f" 4. License: MIT") print(f" 5. Visibility: Public") return False # ===== 上传.gitattributes ===== def upload_gitattributes(): """上传Git LFS配置""" print("\n📝 步骤2: 上传.gitattributes...") gitattributes_content = """# 使用Git LFS跟踪大文件 *.safetensors filter=lfs diff=lfs merge=lfs -text *.pdparams filter=lfs diff=lfs merge=lfs -text *.bin filter=lfs diff=lfs merge=lfs -text *.ckpt filter=lfs diff=lfs merge=lfs -text *.pth filter=lfs diff=lfs merge=lfs -text *.h5 filter=lfs diff=lfs merge=lfs -text *.pb filter=lfs diff=lfs merge=lfs -text """ try: api.upload_file( path_or_fileobj=gitattributes_content.encode(), path_in_repo=".gitattributes", repo_id=REPO_ID, repo_type="model" ) print("✅ .gitattributes 上传成功") return True except Exception as e: print(f"⚠️ .gitattributes上传失败: {e}") return False # ===== 上传README ===== def upload_readme(): """上传README文件""" print("\n📝 步骤3: 上传README...") readme_path = LOCAL_DIR / "README.md" if not readme_path.exists(): print("⚠️ README.md 不存在,跳过") return True try: api.upload_file( path_or_fileobj=str(readme_path), path_in_repo="README.md", repo_id=REPO_ID, repo_type="model" ) print("✅ README.md 上传成功") return True except Exception as e: print(f"❌ README上传失败: {e}") return False # ===== 上传单个文件 ===== def upload_single_file(file_path, repo_path): """上传单个文件""" try: api.upload_file( path_or_fileobj=str(file_path), path_in_repo=repo_path, repo_id=REPO_ID, repo_type="model" ) return True except Exception as e: print(f" ❌ {repo_path}: {e}") return False # ===== 上传models目录 ===== def upload_models_directory(): """分批上传models目录""" print("\n📤 步骤4: 上传models目录...") print(" ⏳ 这可能需要较长时间,请耐心等待...") print() models_dir = LOCAL_DIR / "models" if not models_dir.exists(): print("⚠️ models目录不存在,跳过") return True # 收集所有文件 all_files = [] for root, dirs, files in os.walk(models_dir): # 跳过隐藏目录 dirs[:] = [d for d in dirs if not d.startswith('.')] for file in files: if file.startswith('.'): continue file_path = Path(root) / file relative_path = file_path.relative_to(LOCAL_DIR) all_files.append((file_path, str(relative_path))) print(f" 📊 找到 {len(all_files)} 个文件") print() # 使用进度条上传 success_count = 0 fail_count = 0 for file_path, repo_path in tqdm(all_files, desc="上传进度"): if upload_single_file(file_path, repo_path): success_count += 1 else: fail_count += 1 print() print(f"✅ 成功上传: {success_count} 个文件") if fail_count > 0: print(f"⚠️ 失败: {fail_count} 个文件") return fail_count == 0 # ===== 使用upload_folder (备选方案) ===== def upload_entire_folder(): """上传整个文件夹 (一次性上传)""" print("\n📤 步骤4: 上传整个目录...") print(" ⏳ 这可能需要较长时间,请耐心等待...") try: api.upload_folder( folder_path=str(LOCAL_DIR), repo_id=REPO_ID, repo_type="model", ignore_patterns=[ ".DS_Store", "*.pyc", "__pycache__", "*.sh", "*.py", "fix_upload_issues.md", ".git", ".gitignore" ], multi_commits=True, # 大文件夹分批上传 multi_commits_verbose=True ) print("✅ 所有文件上传成功!") return True except Exception as e: print(f"❌ 上传失败: {e}") print() print("💡 建议:") print(" 1. 检查网络连接") print(" 2. 尝试分批上传") print(" 3. 使用Git LFS方式上传") return False # ===== 验证上传 ===== def verify_upload(): """验证上传结果""" print("\n🔍 步骤5: 验证上传...") try: # 获取仓库信息 info = api.repo_info(repo_id=REPO_ID, repo_type="model") print(f"✅ 仓库验证成功") print(f" 最后更新: {info.last_modified}") return True except Exception as e: print(f"⚠️ 无法验证: {e}") return False # ===== 显示完成信息 ===== def show_completion(): """显示完成信息""" print() print("╔══════════════════════════════════════════════════════════════╗") print("║ ║") print("║ 🎉 上传完成! 🎉 ║") print("║ ║") print("╚══════════════════════════════════════════════════════════════╝") print() print(f"📦 模型仓库: https://huggingface.co/{REPO_ID}") print() print("📚 使用方法:") print() print(" # 使用CLI下载") print(f" huggingface-cli download {REPO_ID} --local-dir ./models") print() print(" # 使用Python下载") print(" from huggingface_hub import snapshot_download") print(f' snapshot_download(repo_id="{REPO_ID}", local_dir="./models")') print() print(" # 在代码中使用") print(" from diffusion_paddle import load_model") print(f' pipe = load_model("{REPO_ID}/sd2.1_base_paddle")') print() print("🌟 别忘了给项目点星!") print(" GitHub: https://github.com/WangchukMind/thangka-restoration-ai") print() # ===== 主函数 ===== def main(): """主函数""" print_banner() # 检查登录状态 try: user = api.whoami() print(f"👤 当前用户: {user['name']}") print() except Exception as e: print("❌ 未登录Hugging Face") print() print("请先登录:") print(" hf auth login") print() return # 步骤1: 创建仓库 if not create_repository(): print() print("⚠️ 请先手动创建仓库,然后重新运行此脚本") return # 步骤2: 上传.gitattributes upload_gitattributes() # 步骤3: 上传README upload_readme() # 步骤4: 上传models目录 # 选择上传方式 print() print("请选择上传方式:") print(" 1. 分批上传 (推荐,更稳定)") print(" 2. 一次性上传 (更快,但可能失败)") try: choice = input("\n请输入选择 (1/2) [1]: ").strip() or "1" if choice == "1": upload_models_directory() else: upload_entire_folder() except KeyboardInterrupt: print("\n\n⚠️ 上传已取消") return # 步骤5: 验证上传 verify_upload() # 显示完成信息 show_completion() if __name__ == "__main__": try: main() except KeyboardInterrupt: print("\n\n⚠️ 程序已中断") except Exception as e: print(f"\n❌ 发生错误: {e}") print() print("请查看 fix_upload_issues.md 了解更多故障排查方法")