File size: 3,999 Bytes
293ab16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import asyncio
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from urllib.parse import urlsplit

import aiohttp
import requests
from tqdm import tqdm

# =============================
# Configuration
# =============================
# Directory holding downloaded model files; override with the MODEL_DIR env var.
MODEL_DIR = os.environ.get("MODEL_DIR", "models")

# URL fetched when callers pass no explicit model URL; override with DEFAULT_MODEL_URL.
DEFAULT_MODEL_URL = os.environ.get(
    "DEFAULT_MODEL_URL",
    "https://huggingface.co/TheBloke/CapybaraHermes-2.5-Mistral-7B-GGUF/resolve/main/capybarahermes-2.5-mistral-7b.Q5_K_S.gguf",
)

# Optional access token; when set, sent as a Bearer token (needed for gated repos).
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")

# =============================
# Utilities
# =============================

def extract_filename(url: str) -> str:
    """Return the last path segment of *url*, used as the local file name.

    The query string and fragment are ignored, so
    ``https://host/repo/resolve/main/model.gguf?download=true`` yields
    ``model.gguf`` rather than ``model.gguf?download=true``. A URL whose
    path ends in ``/`` yields ``""``. Plain strings without a scheme are
    treated as a path (``"a/b.gguf"`` -> ``"b.gguf"``).
    """
    return urlsplit(url).path.rsplit("/", 1)[-1]


def list_available_models() -> list:
    """Return the names of entries in MODEL_DIR whose name ends with ``.gguf``.

    Only the directory itself is scanned (no recursion). If MODEL_DIR is not
    an existing directory, an empty list is returned. Order is whatever
    ``os.listdir`` yields.
    """
    has_dir = os.path.isdir(MODEL_DIR)
    if has_dir:
        entries = os.listdir(MODEL_DIR)
        return [entry for entry in entries if entry.endswith(".gguf")]
    return []


def get_model_path(filename: str) -> str:
    """Resolve *filename* inside MODEL_DIR and return the joined path.

    Raises:
        FileNotFoundError: if nothing exists at the resolved path.
    """
    candidate = os.path.join(MODEL_DIR, filename)
    if os.path.exists(candidate):
        return candidate
    raise FileNotFoundError(f"⚠️ Model not found at {candidate}")

# =============================
# Sync Model Download (fallback)
# =============================

def download_model_if_missing(model_url: str = DEFAULT_MODEL_URL) -> Optional[str]:
    """Download *model_url* into MODEL_DIR unless the target file already exists.

    The response is streamed into a sibling ``<name>.part`` file. That file is
    renamed to its final name only after the whole body has been written. An
    interrupted or failed download therefore never leaves a truncated file
    that a later call would mistake for a complete model.

    Args:
        model_url: URL to fetch; the local file name comes from
            ``extract_filename(model_url)``. If HUGGINGFACE_TOKEN is set it is
            sent as a Bearer token.

    Returns:
        The local model path, or ``None`` if the HTTP request failed with a
        ``requests`` exception (the error is printed).

    Other exceptions (e.g. ``OSError`` while writing) propagate after the
    partial ``.part`` file has been removed.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    model_filename = extract_filename(model_url)
    model_path = os.path.join(MODEL_DIR, model_filename)

    if os.path.exists(model_path):
        print(f"✅ Model already exists: {model_path}")
        return model_path

    headers = {}
    if HUGGINGFACE_TOKEN:
        headers["Authorization"] = f"Bearer {HUGGINGFACE_TOKEN}"

    tmp_path = model_path + ".part"
    try:
        print(f"⬇️ Downloading model from {model_url}")
        with requests.get(model_url, headers=headers, stream=True, timeout=60) as r:
            r.raise_for_status()
            # Missing Content-Length -> None, so tqdm shows an open-ended counter.
            total = int(r.headers.get('content-length', 0)) or None
            with open(tmp_path, 'wb') as file, tqdm(
                total=total, unit='B', unit_scale=True, desc=model_filename
            ) as pbar:
                # 1 MiB chunks: model files are GBs, tiny chunks waste CPU on loop overhead.
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    file.write(chunk)
                    pbar.update(len(chunk))
        # Only a fully written file ever appears under the final name.
        os.replace(tmp_path, model_path)
        print(f"✅ Download complete: {model_path}")
        return model_path

    except requests.exceptions.RequestException as e:
        print(f"❌ Failed to download model: {e}")
        return None
    finally:
        # Clean up after a failed attempt; after a successful rename this is a no-op.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# =============================
# Optional Async Download (advanced)
# =============================

async def _async_download_file(url: str, output_path: str, token: Optional[str] = None):
    """Stream *url* to *output_path* with aiohttp, optionally sending a Bearer *token*.

    aiohttp's default session timeout limits the *total* request time to
    5 minutes, which aborts multi-gigabyte model downloads part-way through.
    Here only connecting and individual socket reads are time-limited.

    The body is written to ``<output_path>.part`` and renamed into place only
    after it has been received completely. A failed or interrupted download
    therefore never leaves a truncated file under the final name. Any error
    (HTTP status via ``raise_for_status``, network, timeout, disk) propagates
    to the caller after the partial file is removed.
    """
    headers = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"

    # No overall deadline; fail only if connecting or a single read stalls for 60 s.
    timeout = aiohttp.ClientTimeout(total=None, sock_connect=60, sock_read=60)
    tmp_path = output_path + ".part"
    try:
        async with aiohttp.ClientSession(headers=headers, timeout=timeout) as session:
            async with session.get(url) as response:
                response.raise_for_status()
                # Missing Content-Length -> None, so tqdm shows an open-ended counter.
                total = int(response.headers.get("Content-Length", 0)) or None
                with open(tmp_path, "wb") as f, tqdm(
                    total=total, unit="B", unit_scale=True, desc=os.path.basename(output_path)
                ) as pbar:
                    # 1 MiB chunks instead of 1 KiB: far fewer iterations for GB-sized files.
                    async for chunk in response.content.iter_chunked(1024 * 1024):
                        f.write(chunk)
                        pbar.update(len(chunk))
        os.replace(tmp_path, output_path)
    finally:
        # Clean up after a failed attempt; after a successful rename this is a no-op.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def _run_coroutine_sync(make_coro):
    """Run the coroutine returned by ``make_coro()`` to completion and return its result.

    ``asyncio.run`` raises RuntimeError when the calling thread already has a
    running event loop (e.g. Jupyter or an async web handler). In that case the
    coroutine runs on a fresh loop in a single worker thread instead. The
    caller (and, in that case, its event loop) still blocks until it finishes.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: the plain, common case.
        return asyncio.run(make_coro())
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Create the coroutine inside the worker so it belongs to that thread's loop.
        return pool.submit(lambda: asyncio.run(make_coro())).result()


def download_model_async(model_url: str = DEFAULT_MODEL_URL) -> Optional[str]:
    """Download *model_url* into MODEL_DIR via aiohttp, skipping files that already exist.

    This is a synchronous entry point: it blocks until the download finishes.
    It also works when called from code that is already inside a running event
    loop. If HUGGINGFACE_TOKEN is set it is sent as a Bearer token.

    Returns:
        The local model path, or ``None`` if the download raised any exception
        (the error is printed and any partially written target file is removed).
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    model_filename = extract_filename(model_url)
    model_path = os.path.join(MODEL_DIR, model_filename)

    if os.path.exists(model_path):
        print(f"✅ Model already exists: {model_path}")
        return model_path

    print(f"⬇️ [Async] Downloading model from {model_url} ...")
    try:
        _run_coroutine_sync(
            lambda: _async_download_file(model_url, model_path, token=HUGGINGFACE_TOKEN)
        )
        print(f"✅ Async download complete: {model_path}")
        return model_path
    except Exception as e:  # best-effort boundary: report and signal failure with None
        # model_path did not exist before this attempt, so anything there now is partial.
        if os.path.exists(model_path):
            os.remove(model_path)
        print(f"❌ Async download failed: {e}")
        return None