import React, { useEffect, useRef, useState } from 'react'
import { Listbox, ListboxButton, ListboxOption, ListboxOptions, Transition } from '@headlessui/react'
import { useModel } from '../contexts/ModelContext'
import { getModelInfo } from '../lib/huggingface'
import { Heart, Download, ChevronDown, Check, ArrowUpDown } from 'lucide-react'

type SortOption = 'likes' | 'downloads' | 'createdAt' | 'name'

/** Display units for `formatNumber`, ordered from largest to smallest. */
const NUMBER_UNITS = [
  { value: 1000000000, suffix: 'B' },
  { value: 1000000, suffix: 'M' },
  { value: 1000, suffix: 'K' }
] as const

/**
 * Formats a count compactly with one decimal place: 1500 -> "1.5K",
 * 2300000 -> "2.3M". Values below 1000 are returned unchanged.
 */
const formatNumber = (num: number): string => {
  let largerSuffix: string | undefined
  for (const { value, suffix } of NUMBER_UNITS) {
    if (num >= value) {
      const scaled = (num / value).toFixed(1)
      // Rounding can push e.g. 999,950 to "1000.0K"; promote it to the next
      // unit ("1.0M") when one exists.
      if (scaled === '1000.0' && largerSuffix !== undefined) {
        return '1.0' + largerSuffix
      }
      return scaled + suffix
    }
    largerSuffix = suffix
  }
  return num.toString()
}

/**
 * Converts a date string to epoch milliseconds for sorting.
 * Missing or unparseable values map to 0 so the sort comparator never
 * returns NaN (a NaN comparator result makes Array.prototype.sort order
 * implementation-defined).
 */
const toTimestamp = (value: string | null | undefined): number => {
  if (!value) {
    return 0
  }
  const time = new Date(value).getTime()
  return Number.isNaN(time) ? 0 : time
}

/**
 * Dropdown for choosing a model from the list exposed by the model context.
 * The list can be sorted by likes, downloads, creation date, or id, and
 * choosing an entry fetches its details via `getModelInfo` and stores them
 * with `setModelInfo`.
 */
const ModelSelector: React.FC = () => {
  const { models, setModelInfo, modelInfo, pipeline } = useModel()
  const [sortBy, setSortBy] = useState<SortOption>('downloads')
  const [sortOrder, setSortOrder] = useState<'asc' | 'desc'>('desc')

  // Id of the most recently started detail request. Responses belonging to an
  // older request are discarded so a slow reply cannot overwrite a newer
  // selection.
  const latestRequestId = useRef(0)

  // Sort a copy of the models based on the current sort criteria
  const sortedModels = React.useMemo(() => {
    const direction = sortOrder === 'desc' ? -1 : 1
    return [...models].sort((a, b) => {
      let comparison: number
      switch (sortBy) {
        case 'downloads':
          comparison = (a.downloads || 0) - (b.downloads || 0)
          break
        case 'createdAt':
          comparison = toTimestamp(a.createdAt) - toTimestamp(b.createdAt)
          break
        case 'name':
          comparison = a.id.localeCompare(b.id)
          break
        case 'likes':
        default:
          comparison = (a.likes || 0) - (b.likes || 0)
          break
      }
      return direction * comparison
    })
  }, [models, sortBy, sortOrder])

  /**
   * Fetches detailed info for `modelId` and stores it as the selected model.
   * Failures are logged and leave the current selection untouched; the
   * returned promise never rejects.
   */
  const fetchAndSetModelInfo = async (modelId: string): Promise<void> => {
    const requestId = ++latestRequestId.current
    try {
      const modelInfoResponse = await getModelInfo(modelId)

      // A newer request was started while this one was in flight.
      if (requestId !== latestRequestId.current) {
        return
      }

      // Use the first non-zero parameter count among the known dtypes.
      let parameters = 0
      if (modelInfoResponse.safetensors) {
        const safetensors = modelInfoResponse.safetensors
        parameters =
          safetensors.parameters.BF16 ||
          safetensors.parameters.F16 ||
          safetensors.parameters.F32 ||
          safetensors.parameters.total ||
          0
      }

      const nextModelInfo = {
        id: modelId,
        name: modelInfoResponse.id || modelId,
        architecture: modelInfoResponse.config?.architectures?.[0] || 'Unknown',
        parameters,
        likes: modelInfoResponse.likes || 0,
        downloads: modelInfoResponse.downloads || 0,
        createdAt: modelInfoResponse.createdAt || '',
        isCompatible: modelInfoResponse.isCompatible,
        incompatibilityReason: modelInfoResponse.incompatibilityReason,
        supportedQuantizations: modelInfoResponse.supportedQuantizations,
        baseId: modelInfoResponse.baseId
      }
      console.log('Fetched model info:', modelInfoResponse)
      setModelInfo(nextModelInfo)
    } catch (error) {
      console.error('Error fetching model info:', error)
    }
  }

  // Select the first model whenever the pipeline or the model list changes.
  // fetchAndSetModelInfo is recreated on every render, so it is deliberately
  // left out of the dependency list to avoid refetching on each render.
  useEffect(() => {
    const firstModel = models[0]
    if (firstModel) {
      void fetchAndSetModelInfo(firstModel.id)
    }
  }, [pipeline, models])

  const handleModelSelect = (modelId: string) => {
    void fetchAndSetModelInfo(modelId)
  }

  // Clicking the active sort key flips the direction; a new key starts descending.
  const handleSortChange = (newSortBy: SortOption) => {
    if (sortBy === newSortBy) {
      setSortOrder((current) => (current === 'asc' ? 'desc' : 'asc'))
    } else {
      setSortBy(newSortBy)
      setSortOrder('desc')
    }
  }

  const selectedModel = models.find(model => model.id === modelInfo.id) || models[0]

  // NOTE(review): the markup below is reproduced exactly as it appeared in the
  // reviewed text, where the JSX element tags appear to be missing. It was
  // left untouched rather than guessed at - confirm against version control.
  return (
handleModelSelect(model.id)}>
{modelInfo.id || 'Select a model'}
{selectedModel && (selectedModel.likes > 0 || selectedModel.downloads > 0) && (
{selectedModel.likes > 0 && (
{formatNumber(selectedModel.likes)}
)} {selectedModel.downloads > 0 && (
{formatNumber(selectedModel.downloads)}
)}
)}
{/* Sort Controls - Always Visible */}
Sort by:
{/* Model Options - Scrollable */}
{sortedModels.map((model) => { const hasStats = model.likes > 0 || model.downloads > 0 return ( `px-3 py-2 cursor-pointer border-b border-gray-100 last:border-b-0 ${ active ? 'bg-gray-50' : '' } ${selected ? 'bg-blue-50' : ''}` } > {({ selected }) => (
{model.id} {selected && ( )}
{/* Stats Display */} {hasStats && (
{model.likes > 0 && (
{formatNumber(model.likes)}
)} {model.downloads > 0 && (
{formatNumber(model.downloads)}
)} {model.createdAt && ( {model.createdAt.split('T')[0]} )}
)}
)}
) })}
) } export default ModelSelector