import { ModelInfoResponse, QuantizationType } from '../types'
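
/**
 * Fetches metadata for a Hugging Face model and annotates it with
 * Transformers.js compatibility info: whether the required config/tokenizer
 * files and ONNX weights are present, and which quantizations are available.
 */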
const getModelInfo = async (modelName: string): Promise<ModelInfoResponse> => {
  const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
  if (!token) {
    throw new Error(
      'Hugging Face token not found. Please set REACT_APP_HUGGINGFACE_TOKEN in your .env file'
    )
  }
  const response = await fetch(
    `https://huggingface.co/api/models/${modelName}`,
    {
      method: 'GET',
      headers: {
        Authorization: `Bearer ${token}`
      }
    }
  )
  if (!response.ok) {
    throw new Error(`Failed to fetch model info: ${response.statusText}`)
  }
  const modelData: ModelInfoResponse = await response.json()
  // Compatible models must ship the config/tokenizer files plus at least one
  // ONNX export under the onnx/ directory
  const requiredFiles = [
    'config.json',
    'tokenizer.json',
    'tokenizer_config.json',
  ]
  const siblingFiles = modelData.siblings?.map(s => s.rfilename) || []
  const missingFiles = requiredFiles.filter(file => !siblingFiles.includes(file))
  const hasOnnxWeights = siblingFiles.some(
    (file) => file.startsWith('onnx/') && file.endsWith('.onnx')
  )
  const isCompatible = missingFiles.length === 0 && hasOnnxWeights
  const incompatibilityReason = isCompatible
    ? ''
    : missingFiles.length > 0
      ? `Missing required files: ${missingFiles.join(', ')}`
      : 'No ONNX weights found in the onnx/ directory'
  // ONNX weight files are named onnx/model_<quantization>.onnx, e.g.
  // onnx/model_fp16.onnx -> 'fp16'. Restricting to the onnx/ directory keeps
  // split('/')[1] from being undefined for files at the repo root.
  const supportedQuantizations = siblingFiles
    .filter((file) =>
      file.startsWith('onnx/') && file.endsWith('.onnx') && file.includes('_')
    )
    .map((file) => file.split('/')[1].split('_')[1].split('.')[0])
    .filter((q) => q !== 'quantized')
  const uniqueSupportedQuantizations = Array.from(new Set(supportedQuantizations))
  // Sort by bit width, smallest first (q4 before q8 before fp16 ...); names
  // without a number sort last
  uniqueSupportedQuantizations.sort((a, b) => {
    const getNumericValue = (str: string) => {
      const match = str.match(/(\d+)/)
      return match ? parseInt(match[1], 10) : Infinity
    }
    return getNumericValue(a) - getNumericValue(b)
  })
  // Models without safetensors metadata (e.g. ONNX-only conversions) are
  // missing the parameter count, so fetch the base model's info and merge it
  // with the compatibility data computed above
  const baseModel = modelData.cardData?.base_model ?? modelData.modelId
  if (baseModel && !modelData.safetensors) {
    const baseModelResponse = await fetch(
      `https://huggingface.co/api/models/${baseModel}`,
      {
        method: 'GET',
        headers: {
          Authorization: `Bearer ${token}`
        }
      }
    )
    if (baseModelResponse.ok) {
      const baseModelData: ModelInfoResponse = await baseModelResponse.json()
      return {
        ...baseModelData,
        id: modelData.id,
        baseId: baseModel,
        isCompatible,
        incompatibilityReason,
        supportedQuantizations: uniqueSupportedQuantizations as QuantizationType[]
      }
    }
  }
  return {
    ...modelData,
    isCompatible,
    incompatibilityReason,
    supportedQuantizations: uniqueSupportedQuantizations as QuantizationType[]
  }
}
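
/**
 * Lists the ten most-downloaded Transformers.js-compatible models for a
 * given pipeline tag.
 */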
const getModelsByPipeline = async (
  pipeline_tag: string
): Promise<ModelInfoResponse[]> => {
  const token = process.env.REACT_APP_HUGGINGFACE_TOKEN
  if (!token) {
    throw new Error(
      'Hugging Face token not found. Please set REACT_APP_HUGGINGFACE_TOKEN in your .env file'
    )
  }
  const response = await fetch(
    `https://huggingface.co/api/models?filter=${pipeline_tag}&filter=transformers.js&sort=downloads`,
    {
      method: 'GET',
      headers: {
        Authorization: `Bearer ${token}`
      }
    }
  )
  if (!response.ok) {
    throw new Error(`Failed to fetch models for pipeline: ${response.statusText}`)
  }
  const models = await response.json()
  // Rerankers share the text-classification pipeline tag, so exclude them
  // before taking the top ten
  if (pipeline_tag === 'text-classification') {
    return models
      .filter(
        (model: ModelInfoResponse) =>
          !model.tags.includes('reranker') && !model.id.includes('reranker')
      )
      .slice(0, 10)
  }
  return models.slice(0, 10)
}
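
/**
 * Estimates a model's size in megabytes from its parameter count and
 * quantization. For example, a 109M-parameter model at q8 (1 byte per
 * parameter) works out to roughly 109_000_000 / (1024 * 1024) ≈ 104 MB.
 */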
function getModelSize(
  parameters: number,
  quantization: QuantizationType
): number {
  let bytesPerParameter: number
  switch (quantization) {
    case 'fp32':
      // 32-bit floating point uses 4 bytes per parameter
      bytesPerParameter = 4
      break
    case 'fp16':
      // 16-bit floating point uses 2 bytes per parameter
      bytesPerParameter = 2
      break
    case 'int8':
    case 'bnb8':
    case 'uint8':
    case 'q8':
      // 8-bit quantizations use 1 byte per parameter
      bytesPerParameter = 1
      break
    case 'bnb4':
    case 'q4':
      // 4-bit quantizations use half a byte per parameter
      bytesPerParameter = 0.5
      break
    default:
      throw new Error(`Unsupported quantization type: ${quantization}`)
  }
  const sizeInBytes = parameters * bytesPerParameter
  // There are 1,024 * 1,024 bytes in a megabyte
  const sizeInMB = sizeInBytes / (1024 * 1024)
  return sizeInMB
}
export { getModelInfo, getModelSize, getModelsByPipeline }
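
// Usage sketch (model id and parameter count are illustrative):
//   const info = await getModelInfo('Xenova/all-MiniLM-L6-v2')
//   if (info.isCompatible && info.supportedQuantizations.includes('q8')) {
//     const sizeMB = getModelSize(22_700_000, 'q8') // ≈ 21.6 MB
//   }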