ThangkaModels / upload_models.py
Wangchuk1376's picture
Upload folder using huggingface_hub
36bf676 verified
#!/usr/bin/env python3
"""
唐卡修复AI模型上传脚本
上传模型到Hugging Face
Developed by Wangchuk Mind
"""
from huggingface_hub import HfApi, create_repo
import os
from pathlib import Path
from tqdm import tqdm
# ===== Configuration =====
# Target repository on the Hugging Face Hub, written as "owner/name".
REPO_ID = "Wangchuk1376/ThangkaModels"
# Directory that contains this script; it doubles as the local upload root.
SCRIPT_DIR = Path(__file__).parent
LOCAL_DIR = SCRIPT_DIR
# Initialise the Hub API client shared by the functions in this file.
api = HfApi()
# ===== Banner =====
def print_banner():
    """Print the boxed start-up banner, including the target repository id."""
    spacer = "║ ║"
    banner_rows = (
        "╔══════════════════════════════════════════════════════════════╗",
        spacer,
        "║ 🎨 唐卡修复AI模型 - Hugging Face上传工具 🎨 ║",
        spacer,
        # The repository id is left-aligned and padded to 38 columns.
        f"║ 上传到: {REPO_ID:38} ║",
        spacer,
        "╚══════════════════════════════════════════════════════════════╝",
        "",  # trailing blank line after the box
    )
    for row in banner_rows:
        print(row)
# ===== Create repository =====
def create_repository():
    """Create the target model repository, or confirm that it already exists.

    Returns:
        bool: True when ``create_repo`` completed without raising; False when
        it raised, in which case manual-creation steps are printed.
    """
    print("🔧 步骤1: 创建/验证仓库...")
    try:
        create_repo(
            repo_id=REPO_ID,
            repo_type="model",
            exist_ok=True,  # an already-existing repository is not an error
            private=False,
        )
    except Exception as e:
        # Derive the hint from REPO_ID so it cannot drift from the configured
        # target. Without a "/" the whole id is treated as the model name.
        owner, _, model_name = REPO_ID.rpartition("/")
        print(f"❌ 创建仓库失败: {e}")
        print()
        print("💡 请手动创建仓库:")
        print(" 1. 访问 https://huggingface.co/new")
        print(f" 2. Owner: {owner}")
        print(f" 3. Model name: {model_name}")
        print(" 4. License: MIT")
        print(" 5. Visibility: Public")
        return False
    # Success reporting lives outside the try block so that only the Hub call
    # itself can be reported as a repository-creation failure.
    print(f"✅ 仓库 {REPO_ID} 已创建/验证")
    print(f"🌐 仓库地址: https://huggingface.co/{REPO_ID}")
    return True
# ===== Upload .gitattributes =====
def upload_gitattributes():
    """Upload a generated ``.gitattributes`` that marks weight files for Git LFS.

    Returns:
        bool: True if the upload call succeeded, False if it raised.
    """
    print("\n📝 步骤2: 上传.gitattributes...")
    # One LFS rule per weight-file extension, preceded by a header comment.
    lfs_extensions = ("safetensors", "pdparams", "bin", "ckpt", "pth", "h5", "pb")
    rules = [
        f"*.{extension} filter=lfs diff=lfs merge=lfs -text"
        for extension in lfs_extensions
    ]
    payload = "\n".join(["# 使用Git LFS跟踪大文件", *rules]) + "\n"
    try:
        api.upload_file(
            path_or_fileobj=payload.encode(),
            path_in_repo=".gitattributes",
            repo_id=REPO_ID,
            repo_type="model",
        )
        print("✅ .gitattributes 上传成功")
        return True
    except Exception as err:
        print(f"⚠️ .gitattributes上传失败: {err}")
        return False
# ===== Upload README =====
def upload_readme():
    """Upload ``README.md`` from the local directory when one is present.

    Returns:
        bool: True when the file was uploaded or does not exist (the step is
        skipped); False when the upload call raised.
    """
    print("\n📝 步骤3: 上传README...")
    readme_file = LOCAL_DIR / "README.md"
    if readme_file.exists():
        try:
            api.upload_file(
                path_or_fileobj=str(readme_file),
                path_in_repo="README.md",
                repo_id=REPO_ID,
                repo_type="model",
            )
            print("✅ README.md 上传成功")
            outcome = True
        except Exception as err:
            print(f"❌ README上传失败: {err}")
            outcome = False
        return outcome
    # A missing README is not treated as a failure.
    print("⚠️ README.md 不存在,跳过")
    return True
# ===== Upload a single file =====
def upload_single_file(file_path, repo_path):
    """Upload one local file to the repository.

    Args:
        file_path: Local path of the file to upload (``str`` or ``Path``).
        repo_path: Destination path inside the repository.

    Returns:
        bool: True if the upload call succeeded, False if it raised.
    """
    try:
        api.upload_file(
            path_or_fileobj=str(file_path),
            path_in_repo=repo_path,
            repo_id=REPO_ID,
            repo_type="model",
        )
    except Exception as e:
        # This function is called from inside a tqdm loop; tqdm.write prints
        # the message without breaking the active progress bar, and behaves
        # like a plain print when no bar is active.
        tqdm.write(f" ❌ {repo_path}: {e}")
        return False
    return True
# ===== Upload the models directory =====
def _collect_model_files(models_dir, base_dir):
    """Return sorted ``(local_path, repo_path)`` pairs for files under *models_dir*.

    Hidden files and hidden directories (names starting with ".") are skipped.
    ``repo_path`` is the path relative to *base_dir*, using "/" separators.
    """
    collected = []
    for root, dirs, files in os.walk(models_dir):
        # Prune hidden directories in place so os.walk never descends into them.
        dirs[:] = [d for d in dirs if not d.startswith(".")]
        for name in files:
            if name.startswith("."):
                continue
            local_path = Path(root) / name
            # Repository paths must use forward slashes; str() on a Windows
            # path would produce backslashes that end up in the file name.
            repo_path = local_path.relative_to(base_dir).as_posix()
            collected.append((local_path, repo_path))
    # os.walk order depends on the filesystem; sort for a reproducible run.
    collected.sort(key=lambda pair: pair[1])
    return collected


def upload_models_directory():
    """Upload every file under ``LOCAL_DIR / "models"``, one file per call.

    Returns:
        bool: True when the directory is missing (step skipped) or every file
        uploaded successfully; False when at least one upload failed.
    """
    print("\n📤 步骤4: 上传models目录...")
    print(" ⏳ 这可能需要较长时间,请耐心等待...")
    print()
    models_dir = LOCAL_DIR / "models"
    if not models_dir.exists():
        print("⚠️ models目录不存在,跳过")
        return True
    # Collect all files up front so the progress bar knows the total.
    all_files = _collect_model_files(models_dir, LOCAL_DIR)
    print(f" 📊 找到 {len(all_files)} 个文件")
    print()
    # Upload with a progress bar, counting successes and failures.
    success_count = 0
    fail_count = 0
    for file_path, repo_path in tqdm(all_files, desc="上传进度"):
        if upload_single_file(file_path, repo_path):
            success_count += 1
        else:
            fail_count += 1
    print()
    print(f"✅ 成功上传: {success_count} 个文件")
    if fail_count > 0:
        print(f"⚠️ 失败: {fail_count} 个文件")
    return fail_count == 0
# ===== Alternative: upload with upload_folder =====
def upload_entire_folder():
    """Upload the whole local directory with a single ``upload_folder`` call.

    Scripts, caches and VCS metadata are excluded through ``ignore_patterns``.

    Returns:
        bool: True if the upload call succeeded, False if it raised.
    """
    print("\n📤 步骤4: 上传整个目录...")
    print(" ⏳ 这可能需要较长时间,请耐心等待...")
    try:
        # NOTE(review): the experimental ``multi_commits`` /
        # ``multi_commits_verbose`` keywords are no longer accepted by
        # ``upload_folder`` in recent huggingface_hub releases, so passing
        # them makes this call fail with TypeError. They are omitted here;
        # confirm against the installed version. For very large folders the
        # library provides ``upload_large_folder``.
        api.upload_folder(
            folder_path=str(LOCAL_DIR),
            repo_id=REPO_ID,
            repo_type="model",
            ignore_patterns=[
                ".DS_Store",
                # The bare name only matches at the folder root; this entry
                # covers copies inside sub-directories such as models/.
                "**/.DS_Store",
                "*.pyc",
                "__pycache__",
                "*.sh",
                "*.py",
                "fix_upload_issues.md",
                ".git",
                ".gitignore",
            ],
        )
    except Exception as e:
        print(f"❌ 上传失败: {e}")
        print()
        print("💡 建议:")
        print(" 1. 检查网络连接")
        print(" 2. 尝试分批上传")
        print(" 3. 使用Git LFS方式上传")
        return False
    print("✅ 所有文件上传成功!")
    return True
# ===== Verify upload =====
def verify_upload():
    """Fetch the repository info and print its last-modified timestamp.

    Returns:
        bool: True if the repository info was retrieved and printed, False if
        the lookup raised.
    """
    print("\n🔍 步骤5: 验证上传...")
    try:
        # Query the repository metadata from the Hub.
        repo_details = api.repo_info(repo_id=REPO_ID, repo_type="model")
        print("✅ 仓库验证成功")
        print(f" 最后更新: {repo_details.last_modified}")
    except Exception as err:
        print(f"⚠️ 无法验证: {err}")
        return False
    else:
        return True
# ===== Completion message =====
def show_completion():
    """Print the completion box, the repository URL and usage examples."""
    spacer = "║ ║"
    message_lines = [
        "",
        "╔══════════════════════════════════════════════════════════════╗",
        spacer,
        "║ 🎉 上传完成! 🎉 ║",
        spacer,
        "╚══════════════════════════════════════════════════════════════╝",
        "",
        f"📦 模型仓库: https://huggingface.co/{REPO_ID}",
        "",
        "📚 使用方法:",
        "",
        # Download with the CLI.
        " # 使用CLI下载",
        f" huggingface-cli download {REPO_ID} --local-dir ./models",
        "",
        # Download from Python.
        " # 使用Python下载",
        " from huggingface_hub import snapshot_download",
        f' snapshot_download(repo_id="{REPO_ID}", local_dir="./models")',
        "",
        # Use the model in code.
        " # 在代码中使用",
        " from diffusion_paddle import load_model",
        f' pipe = load_model("{REPO_ID}/sd2.1_base_paddle")',
        "",
        "🌟 别忘了给项目点星!",
        " GitHub: https://github.com/WangchukMind/thangka-restoration-ai",
        "",
    ]
    # A single print of the joined lines emits the same text as one print
    # call per line.
    print("\n".join(message_lines))
# ===== Main function =====
def main():
    """Run the interactive upload workflow from login check to verification.

    Steps: print the banner, check the Hub login, create/verify the
    repository, upload ``.gitattributes`` and the README, ask which upload
    mode to use, upload, then verify. The completion banner is shown only
    when the chosen upload step reported success.

    Returns:
        None. Progress and failures are reported on stdout.
    """
    print_banner()
    # Check login state before doing any work.
    try:
        user = api.whoami()
        print(f"👤 当前用户: {user['name']}")
        print()
    except Exception:
        print("❌ 未登录Hugging Face")
        print()
        print("请先登录:")
        print(" hf auth login")
        print()
        return
    # Step 1: create the repository.
    if not create_repository():
        print()
        print("⚠️ 请先手动创建仓库,然后重新运行此脚本")
        return
    # Steps 2-3: best effort; each helper prints its own failure message and
    # the workflow continues regardless.
    upload_gitattributes()
    upload_readme()
    # Step 4: choose the upload mode and upload.
    print()
    print("请选择上传方式:")
    print(" 1. 分批上传 (推荐,更稳定)")
    print(" 2. 一次性上传 (更快,但可能失败)")
    try:
        try:
            choice = input("\n请输入选择 (1/2) [1]: ").strip() or "1"
        except EOFError:
            # No interactive stdin (e.g. piped or CI run): take the default.
            print()
            choice = "1"
        if choice not in ("1", "2"):
            # A typo must not silently trigger the whole-folder upload.
            print(f"⚠️ 无效的选择 '{choice}',使用默认方式 1")
            choice = "1"
        if choice == "1":
            upload_ok = upload_models_directory()
        else:
            upload_ok = upload_entire_folder()
    except KeyboardInterrupt:
        print("\n\n⚠️ 上传已取消")
        return
    # Step 5: verify the upload.
    verify_upload()
    # Only announce completion when the upload step reported success.
    if upload_ok:
        show_completion()
    else:
        print()
        print("⚠️ 部分文件上传失败,请检查上面的错误信息后重新运行此脚本")
if __name__ == "__main__":
    # Top-level boundary: report the problem, then exit with a non-zero
    # status so shells and CI jobs can detect that the run did not finish.
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⚠️ 程序已中断")
        # 130 is the conventional exit status for termination by Ctrl-C.
        raise SystemExit(130) from None
    except Exception as e:
        print(f"\n❌ 发生错误: {e}")
        print()
        print("请查看 fix_upload_issues.md 了解更多故障排查方法")
        raise SystemExit(1) from None