File size: 4,348 Bytes
6ebf2fd
96812c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ebf2fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96812c9
 
59a1fe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1150456
 
 
 
 
 
 
6ebf2fd
1150456
 
 
6ebf2fd
1150456
 
6ebf2fd
 
 
 
1150456
 
6ebf2fd
 
1150456
6ebf2fd
1150456
 
 
 
 
 
 
 
 
 
59a1fe9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import { ModelInfoResponse, QuantizationType } from "../types"

/**
 * Fetches model metadata from the Hugging Face Hub and annotates it with
 * transformers.js compatibility information.
 *
 * A model is considered compatible when it ships config.json,
 * tokenizer.json and tokenizer_config.json AND at least one .onnx file
 * under the onnx/ directory. Quantization suffixes (e.g. "fp16" from
 * onnx/model_fp16.onnx) are collected, deduplicated and sorted by the
 * numeric bit width embedded in their name.
 *
 * If the model card names a base model and this repo carries no
 * safetensors metadata of its own, the base model's info is fetched and
 * merged (keeping this repo's id and compatibility results).
 *
 * @param modelName Hub repo id, e.g. "Xenova/all-MiniLM-L6-v2".
 * @returns Model info augmented with isCompatible, incompatibilityReason
 *          and supportedQuantizations.
 * @throws If the HF token is missing or the primary request fails.
 */
const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
  const token = process.env.REACT_APP_HUGGINGFACE_TOKEN

  if (!token) {
    throw new Error(
      'Hugging Face token not found. Please set REACT_APP_HUGGINGFACE_TOKEN in your .env file'
    )
  }

  const response = await fetch(
    `https://huggingface.co/api/models/${modelName}`,
    {
      method: 'GET',
      headers: {
        Authorization: `Bearer ${token}`
      }
    }
  )

  if (!response.ok) {
    throw new Error(`Failed to fetch model info: ${response.statusText}`)
  }

  const modelData: ModelInfoResponse = await response.json()

  const requiredFiles = [
    'config.json',
    'tokenizer.json',
    'tokenizer_config.json',
  ]

  const siblingFiles = modelData.siblings?.map(s => s.rfilename) ?? []
  const missingFiles = requiredFiles.filter(file => !siblingFiles.includes(file))
  const hasOnnxWeights = siblingFiles.some(
    (file) => file.endsWith('.onnx') && file.startsWith('onnx/')
  )
  const isCompatible = missingFiles.length === 0 && hasOnnxWeights

  // Report every failed check. Previously a repo that only lacked ONNX
  // weights produced the misleading "Missing required files: " message
  // with an empty list.
  const reasons: string[] = []
  if (missingFiles.length > 0) {
    reasons.push(`Missing required files: ${missingFiles.join(', ')}`)
  }
  if (!hasOnnxWeights) {
    reasons.push('No ONNX weights found under onnx/')
  }
  const incompatibilityReason = isCompatible ? '' : reasons.join('; ')

  // Extract the quantization suffix from names like "onnx/model_fp16.onnx".
  // The previous split('/')[1].split('_')[1] chain crashed on .onnx files
  // outside a subdirectory and returned a bogus middle token ("merged")
  // for names like "decoder_model_merged_fp16.onnx"; matching the last
  // "_suffix" before ".onnx" handles both.
  const supportedQuantizations = siblingFiles
    .filter((file) => file.endsWith('.onnx'))
    .map((file) => {
      const filename = file.split('/').pop() ?? ''
      const match = filename.match(/_([^_.]+)\.onnx$/)
      return match ? match[1] : null
    })
    .filter((q): q is string => q !== null && q !== 'quantized')
  const uniqueSupportedQuantizations = Array.from(new Set(supportedQuantizations))
  // Order by embedded bit width (q4 < q8 < fp16 < fp32); names without a
  // number sort last.
  uniqueSupportedQuantizations.sort((a, b) => {
    const getNumericValue = (str: string) => {
      const match = str.match(/(\d+)/)
      return match ? parseInt(match[1], 10) : Infinity
    }
    return getNumericValue(a) - getNumericValue(b)
  })

  // If there's a base model, fetch its info and merge with compatibility data.
  // NOTE(review): when cardData.base_model is absent this falls back to
  // modelId, i.e. re-fetches the same repo — preserved as-is; confirm intent.
  const baseModel = modelData.cardData?.base_model ?? modelData.modelId
  if (baseModel && !modelData.safetensors) {
    const baseModelResponse = await fetch(
      `https://huggingface.co/api/models/${baseModel}`,
      {
        method: 'GET',
        headers: {
          Authorization: `Bearer ${token}`
        }
      }
    )

    // Best-effort: if the base model lookup fails we fall through and
    // return the repo's own data instead of erroring out.
    if (baseModelResponse.ok) {
      const baseModelData: ModelInfoResponse = await baseModelResponse.json()

      return {
        ...baseModelData,
        id: modelData.id,
        baseId: baseModel,
        isCompatible,
        incompatibilityReason,
        supportedQuantizations: uniqueSupportedQuantizations as QuantizationType[]
      }
    }
  }

  return {
    ...modelData,
    isCompatible,
    incompatibilityReason,
    supportedQuantizations: uniqueSupportedQuantizations as QuantizationType[]
  }
}

/**
 * Lists the top-10 most-downloaded Hub models for a pipeline tag that are
 * also tagged for transformers.js.
 *
 * For 'text-classification' the list is additionally cleaned of reranker
 * models (by tag or id) before truncation.
 *
 * @param pipeline_tag Hub pipeline tag, e.g. "text-classification".
 * @returns At most 10 model info entries, sorted by downloads.
 * @throws If the HF token is missing or the request fails.
 */
const getModelsByPipeline = async (
  pipeline_tag: string
): Promise<ModelInfoResponse[]> => {
  const token = process.env.REACT_APP_HUGGINGFACE_TOKEN

  if (!token) {
    throw new Error(
      'Hugging Face token not found. Please set REACT_APP_HUGGINGFACE_TOKEN in your .env file'
    )
  }

  const response = await fetch(
    `https://huggingface.co/api/models?filter=${pipeline_tag}&filter=transformers.js&sort=downloads`,
    {
      method: 'GET',
      headers: {
        Authorization: `Bearer ${token}`
      }
    }
  )

  if (!response.ok) {
    throw new Error(`Failed to fetch models for pipeline: ${response.statusText}`)
  }
  const models = await response.json()
  if (pipeline_tag === 'text-classification') {
    // Optional chaining: listing entries are not guaranteed to carry a
    // tags array, and the previous model.tags.includes(...) threw on those.
    return models
      .filter(
        (model: ModelInfoResponse) =>
          !model.tags?.includes('reranker') && !model.id.includes('reranker')
      )
      .slice(0, 10)
  }
  return models.slice(0, 10)
}

/**
 * Estimates a model's in-memory size in megabytes for a given
 * quantization level.
 *
 * @param parameters   Total parameter count of the model.
 * @param quantization Quantization type determining bytes per parameter.
 * @returns Estimated size in MB (1 MB = 1024 * 1024 bytes).
 * @throws If the quantization type is not recognized — the previous
 *         switch had no default, so an unhandled value silently yielded NaN.
 */
function getModelSize(
  parameters: number,
  quantization: QuantizationType
): number {
  // Bytes needed to store a single parameter at each quantization level.
  const bytesPerParameterTable: Record<string, number> = {
    fp32: 4, // 32-bit floating point uses 4 bytes
    fp16: 2, // 16-bit floating point
    // 8-bit variants: one byte per parameter
    int8: 1,
    bnb8: 1,
    uint8: 1,
    q8: 1,
    // 4-bit variants: two parameters packed per byte
    bnb4: 0.5,
    q4: 0.5,
  }

  const bytesPerParameter = bytesPerParameterTable[quantization]
  if (bytesPerParameter === undefined) {
    throw new Error(`Unknown quantization type: ${quantization}`)
  }

  // There are 1,024 * 1,024 bytes in a megabyte
  const sizeInBytes = parameters * bytesPerParameter
  const sizeInMB = sizeInBytes / (1024 * 1024)

  return sizeInMB
}


export { getModelInfo, getModelSize, getModelsByPipeline }