// Commit 79eafc9 by Vokturz: "Add StyleTTS2 support with KokoroTTS integration"
import React, { useCallback, useEffect, useState } from 'react'
import {
Listbox,
ListboxButton,
ListboxOption,
ListboxOptions,
Transition
} from '@headlessui/react'
import { useModel } from '../contexts/ModelContext'
import { getModelInfo } from '../lib/huggingface'
import {
Heart,
Download,
ChevronDown,
Check,
ArrowDown,
ArrowUp,
Plus,
Search,
X
} from 'lucide-react'
import Tooltip from './Tooltip'
import { ModelInfoResponse } from '@/types'
type SortOption = 'likes' | 'downloads' | 'createdAt' | 'name'
type SortOrder = 'asc' | 'desc'

/** Magnitude suffixes used by `formatNumber`, largest first. */
const COMPACT_UNITS: ReadonlyArray<readonly [number, string]> = [
  [1000000000, 'B'],
  [1000000, 'M'],
  [1000, 'K']
]

/**
 * Formats a count compactly with one decimal place, e.g. 1500 -> "1.5K",
 * 2300000 -> "2.3M". Values below 1000 are returned via `toString()` unchanged.
 *
 * If rounding to one decimal would produce "1000.0" of a unit (999950 would
 * otherwise print as "1000.0K"), the next larger unit is used instead ("1.0M").
 */
function formatNumber(num: number): string {
  let largerUnit: readonly [number, string] | undefined
  for (const unit of COMPACT_UNITS) {
    const [divisor, suffix] = unit
    if (num >= divisor) {
      const mantissa = (num / divisor).toFixed(1)
      // num < largerUnit's divisor here, so the mantissa can reach at most "1000.0".
      if (largerUnit && mantissa === '1000.0') {
        const [largerDivisor, largerSuffix] = largerUnit
        return (num / largerDivisor).toFixed(1) + largerSuffix
      }
      return mantissa + suffix
    }
    largerUnit = unit
  }
  return num.toString()
}

/**
 * Converts an ISO-like date string to epoch milliseconds for sorting.
 * Missing or unparseable dates map to 0 so the sort comparator never returns NaN
 * (a NaN comparator result makes `Array.prototype.sort` order unspecified).
 */
function toTimestamp(date: string | null | undefined): number {
  const time = date ? new Date(date).getTime() : Number.NaN
  return Number.isNaN(time) ? 0 : time
}

/** Arrow shown next to the active sort button, pointing in the current direction. */
function SortIcon({ sortOrder }: { sortOrder: SortOrder }) {
  return sortOrder === 'asc' ? (
    <ArrowUp className="w-3 h-3 ml-1" />
  ) : (
    <ArrowDown className="w-3 h-3 ml-1" />
  )
}

/**
 * Compact likes/downloads summary. Renders nothing unless at least one count is
 * positive. `className` is appended to the container classes and `children` are
 * rendered after the counts (only when the summary itself is rendered).
 */
function ModelStats({
  likes,
  downloads,
  className,
  children
}: {
  likes?: number
  downloads?: number
  className?: string
  children?: React.ReactNode
}) {
  const hasLikes = (likes ?? 0) > 0
  const hasDownloads = (downloads ?? 0) > 0
  if (!hasLikes && !hasDownloads) {
    return null
  }
  const containerClass = [
    'flex items-center space-x-3 text-xs text-gray-500',
    className
  ]
    .filter(Boolean)
    .join(' ')
  return (
    <div className={containerClass}>
      {hasLikes && (
        <div className="flex items-center space-x-1">
          <Heart className="w-3 h-3 text-red-500" />
          <span>{formatNumber(likes ?? 0)}</span>
        </div>
      )}
      {hasDownloads && (
        <div className="flex items-center space-x-1">
          <Download className="w-3 h-3 text-green-500" />
          <span>{formatNumber(downloads ?? 0)}</span>
        </div>
      )}
      {children}
    </div>
  )
}

/** Sort buttons shown in the dropdown header, in display order. */
const SORT_BUTTONS: ReadonlyArray<{
  option: SortOption
  label: string
  icon?: typeof Heart
}> = [
  { option: 'name', label: 'Name' },
  { option: 'likes', label: 'Likes', icon: Heart },
  { option: 'downloads', label: 'Downloads', icon: Download },
  { option: 'createdAt', label: 'Date' }
]

/**
 * Dropdown for choosing the model used by the current pipeline.
 *
 * Lists `models` from ModelContext, sortable by name, likes, downloads or creation
 * date. When nothing is selected and no custom model is active, it selects the first
 * model in the current sort order. Users can also load an arbitrary model id
 * ("custom model"). Every selection fetches detailed info via `getModelInfo` and
 * stores it with `setModelInfo`. While `isFetching` is set or `models` is empty, a
 * skeleton placeholder is rendered.
 */
function ModelSelector() {
  const {
    models,
    setModelInfo,
    modelInfo,
    pipeline,
    isFetching,
    setIsFetching
  } = useModel()
  const [sortBy, setSortBy] = useState<SortOption>('createdAt')
  const [sortOrder, setSortOrder] = useState<SortOrder>('desc')
  const [showCustomInput, setShowCustomInput] = useState(false)
  const [customModelName, setCustomModelName] = useState('')
  const [isLoadingCustomModel, setIsLoadingCustomModel] = useState(false)
  const [customModelError, setCustomModelError] = useState('')
  const [isCustomModel, setIsCustomModel] = useState(false)

  // Sorted copy of `models` for the current criteria; `models` itself is never mutated.
  const sortedModels = React.useMemo(() => {
    return [...models].sort((a, b) => {
      let comparison: number
      switch (sortBy) {
        case 'downloads':
          comparison = (a.downloads || 0) - (b.downloads || 0)
          break
        case 'createdAt':
          comparison = toTimestamp(a.createdAt) - toTimestamp(b.createdAt)
          break
        case 'name':
          comparison = a.id.localeCompare(b.id)
          break
        case 'likes':
        default:
          comparison = (a.likes || 0) - (b.likes || 0)
          break
      }
      return sortOrder === 'desc' ? -comparison : comparison
    })
  }, [models, sortBy, sortOrder])

  /**
   * Fetches detailed info for `model` via `getModelInfo` and stores it as the
   * selected model, recording whether it was loaded as a custom model.
   * Always clears `isFetching`. Errors are logged and rethrown so callers such as
   * the custom-model form can react to them.
   */
  const fetchAndSetModelInfo = useCallback(
    async (model: Pick<ModelInfoResponse, 'id' | 'tags'>, isCustom = false) => {
      try {
        const modelInfoResponse = await getModelInfo(model.id, pipeline)

        // Prefer the BF16, then F16, then F32 parameter count, then the total.
        let parameters = 0
        if (modelInfoResponse.safetensors) {
          const counts = modelInfoResponse.safetensors.parameters
          parameters =
            counts.BF16 || counts.F16 || counts.F32 || counts.total || 0
        }

        // Custom models start with no list tags, so merge both sources defensively.
        const allTags = [
          ...(model.tags ?? []),
          ...(modelInfoResponse.tags ?? [])
        ]

        const nextModelInfo = {
          id: model.id,
          name: modelInfoResponse.id || model.id,
          architecture:
            modelInfoResponse.config?.architectures?.[0] || 'Unknown',
          parameters,
          likes: modelInfoResponse.likes || 0,
          downloads: modelInfoResponse.downloads || 0,
          createdAt: modelInfoResponse.createdAt || '',
          isCompatible: modelInfoResponse.isCompatible,
          incompatibilityReason: modelInfoResponse.incompatibilityReason,
          supportedQuantizations: modelInfoResponse.supportedQuantizations,
          baseId: modelInfoResponse.baseId,
          readme: modelInfoResponse.readme,
          hasChatTemplate: Boolean(
            modelInfoResponse.config?.tokenizer_config?.chat_template
          ),
          isStyleTTS2: allTags.includes('style_text_to_speech_2'),
          widgetData: modelInfoResponse.widgetData,
          voices: modelInfoResponse.voices
        }
        setModelInfo(nextModelInfo)
        setIsCustomModel(isCustom)
      } catch (error) {
        console.error('Error fetching model info:', error)
        throw error
      } finally {
        setIsFetching(false)
      }
    },
    [setModelInfo, pipeline, setIsFetching]
  )

  // Fire-and-forget selection of a listed model. fetchAndSetModelInfo already logs
  // and rethrows, so the rejection is consumed here to avoid unhandled rejections.
  const handleModelSelect = useCallback(
    (model: ModelInfoResponse) => {
      void fetchAndSetModelInfo(model, false).catch(() => {
        // Already logged by fetchAndSetModelInfo.
      })
    },
    [fetchAndSetModelInfo]
  )

  // On pipeline change: clear all custom-model state and, for every pipeline other
  // than 'feature-extraction', switch the sort to downloads.
  useEffect(() => {
    setIsCustomModel(false)
    setShowCustomInput(false)
    setCustomModelName('')
    setCustomModelError('')
    if (pipeline !== 'feature-extraction') {
      setSortBy('downloads')
    }
  }, [pipeline])

  // Select the first model in the current order once models are available, unless a
  // model (or a custom model) is already selected.
  useEffect(() => {
    const firstModel = sortedModels[0]
    if (firstModel && !isCustomModel && !modelInfo) {
      handleModelSelect(firstModel)
    }
  }, [sortedModels, handleModelSelect, isCustomModel, modelInfo])

  // Clicking the active criterion flips the direction; a new criterion starts descending.
  const handleSortChange = (newSortBy: SortOption) => {
    if (sortBy === newSortBy) {
      setSortOrder((current) => (current === 'asc' ? 'desc' : 'asc'))
    } else {
      setSortBy(newSortBy)
      setSortOrder('desc')
    }
  }

  // Hides the custom-model input and clears its text and error message.
  const closeCustomInput = () => {
    setShowCustomInput(false)
    setCustomModelName('')
    setCustomModelError('')
  }

  const handleCustomModelLoad = async () => {
    // The Load button is disabled while loading, but Enter in the input is not;
    // ignore repeated submissions until the current request settles.
    if (isLoadingCustomModel) {
      return
    }
    const modelId = customModelName.trim()
    if (!modelId) {
      setCustomModelError('Please enter a model name')
      return
    }
    setIsLoadingCustomModel(true)
    setCustomModelError('')
    try {
      await fetchAndSetModelInfo({ id: modelId, tags: [] }, true)
      setShowCustomInput(false)
      setCustomModelName('')
    } catch {
      setCustomModelError(
        'Failed to load model. Please check the model name and try again.'
      )
    } finally {
      setIsLoadingCustomModel(false)
    }
  }

  // Leave custom mode and fall back to the first model in the current order.
  const handleRemoveCustomModel = () => {
    setIsCustomModel(false)
    const firstModel = sortedModels[0]
    if (firstModel) {
      handleModelSelect(firstModel)
    }
  }

  const handleCustomInputKeyDown = (e: React.KeyboardEvent) => {
    if (e.key === 'Enter') {
      void handleCustomModelLoad()
    } else if (e.key === 'Escape') {
      closeCustomInput()
    }
  }

  const selectedModel =
    models.find((model) => model.id === modelInfo?.id) || models[0]

  if (isCustomModel) {
    return (
      <div className="relative">
        <div className="w-full px-3 py-2 border border-gray-300 rounded-md bg-white flex items-center justify-between">
          <div className="flex flex-col flex-1 min-w-0">
            <span className="truncate font-medium">
              {modelInfo?.id || 'Custom model'}
            </span>
          </div>
          <div className="flex items-center space-x-3">
            {modelInfo && (
              <ModelStats
                likes={modelInfo.likes}
                downloads={modelInfo.downloads}
              />
            )}
            <button
              onClick={handleRemoveCustomModel}
              className="p-1 text-gray-400 hover:text-red-500 transition-colors"
              title="Remove custom model"
            >
              <X className="w-4 h-4" />
            </button>
          </div>
        </div>
      </div>
    )
  }

  if (isFetching || models.length === 0) {
    return (
      <div className="relative">
        <div className="w-full px-3 py-2 border border-gray-300 rounded-md bg-white flex items-center justify-between animate-pulse h-10">
          <div className="flex flex-col flex-1 min-w-0 space-y-2">
            <div className="h-4 bg-gray-200 rounded-sm w-3/4"></div>
          </div>
          <div className="flex items-center space-x-3">
            <div className="flex items-center space-x-3">
              <div className="flex items-center space-x-1">
                <Heart className="w-3 h-3 text-red-500" />
                <div className="h-3 bg-gray-200 rounded-sm w-8"></div>
              </div>
              <div className="flex items-center space-x-1">
                <Download className="w-3 h-3 text-green-500" />
                <div className="h-3 bg-gray-200 rounded-sm w-8"></div>
              </div>
            </div>
            <div className="w-4 h-4 bg-gray-200 rounded-sm"></div>
          </div>
        </div>
      </div>
    )
  }

  return (
    <div className="relative">
      <Listbox
        value={selectedModel}
        onChange={(model) => handleModelSelect(model)}
      >
        <div className="relative">
          <ListboxButton className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-hidden focus:ring-2 focus:ring-blue-500 focus:border-transparent bg-white text-left flex items-center justify-between">
            <div className="flex items-center justify-between w-full">
              <div className="flex flex-col flex-1 min-w-0">
                <Tooltip content={modelInfo?.id || 'Select a model'}>
                  <span className="truncate font-medium block">
                    {modelInfo?.id || 'Select a model'}
                  </span>
                </Tooltip>
              </div>
              <div className="flex items-center space-x-3">
                {selectedModel && (
                  <ModelStats
                    likes={selectedModel.likes}
                    downloads={selectedModel.downloads}
                  />
                )}
                <ChevronDown className="w-4 h-4 ui-open:rotate-180 transition-transform shrink-0" />
              </div>
            </div>
          </ListboxButton>

          <Transition
            enter="transition duration-100 ease-out"
            enterFrom="transform scale-95 opacity-0"
            enterTo="transform scale-100 opacity-100"
            leave="transition duration-75 ease-out"
            leaveFrom="transform scale-100 opacity-100"
            leaveTo="transform scale-95 opacity-0"
          >
            <ListboxOptions className="absolute z-10 w-full mt-1 bg-white border border-gray-300 rounded-md shadow-lg max-h-96 overflow-hidden focus:outline-hidden">
              {showCustomInput ? (
                // Custom model input
                <div className="px-3 py-3 border-b border-gray-200 bg-gray-50 sticky top-0 z-10">
                  <div className="space-y-2">
                    <div className="flex items-center space-x-2">
                      <input
                        type="text"
                        value={customModelName}
                        onChange={(e) => setCustomModelName(e.target.value)}
                        onKeyDown={handleCustomInputKeyDown}
                        placeholder="onnx-community/Qwen3-0.6B-ONNX"
                        className="flex-1 px-2 py-1 text-sm border border-gray-300 rounded-sm focus:outline-hidden focus:ring-1 focus:ring-blue-500"
                        autoFocus
                      />
                      <button
                        onClick={() => void handleCustomModelLoad()}
                        disabled={isLoadingCustomModel}
                        className="px-3 py-1 text-sm bg-blue-600 text-white rounded-sm hover:bg-blue-700 disabled:opacity-50 disabled:cursor-not-allowed flex items-center space-x-1"
                      >
                        {isLoadingCustomModel ? (
                          <div className="w-3 h-3 border border-white border-t-transparent rounded-full animate-spin" />
                        ) : (
                          <Search className="w-3 h-3" />
                        )}
                        <span>Load</span>
                      </button>
                      <button
                        onClick={closeCustomInput}
                        className="px-2 py-1 text-sm text-gray-600 hover:text-gray-800"
                      >
                        Cancel
                      </button>
                    </div>
                    {customModelError && (
                      <p className="text-xs text-red-600">{customModelError}</p>
                    )}
                    <p className="text-xs text-gray-500">
                      Press Enter to load or Escape to cancel
                    </p>
                  </div>
                </div>
              ) : (
                <div className="sticky top-0 z-10 bg-gray-50 border-b border-gray-200 p-3 space-y-3">
                  <button
                    onClick={() => setShowCustomInput(true)}
                    className="w-full flex items-center justify-center space-x-2 px-3 py-2 text-sm text-blue-600 hover:bg-blue-50 rounded-sm transition-colors"
                  >
                    <Plus className="w-4 h-4" />
                    <span>Load Custom Model</span>
                  </button>

                  <div className="flex items-center space-x-2 text-xs">
                    <span className="text-gray-600 font-medium">Sort by:</span>
                    {SORT_BUTTONS.map(({ option, label, icon: Icon }) => (
                      <button
                        key={option}
                        onClick={() => handleSortChange(option)}
                        className={`px-2 py-1 rounded flex items-center space-x-1 ${
                          sortBy === option
                            ? 'bg-blue-100 text-blue-700'
                            : 'text-gray-600 hover:bg-gray-100'
                        }`}
                      >
                        {Icon && <Icon className="w-3 h-3" />}
                        <span>{label}</span>
                        {sortBy === option && (
                          <SortIcon sortOrder={sortOrder} />
                        )}
                      </button>
                    ))}
                  </div>
                </div>
              )}

              {/* Model options (scrollable) */}
              {!showCustomInput && (
                <div className="overflow-auto max-h-48">
                  {sortedModels.map((model) => (
                    <ListboxOption
                      key={model.id}
                      value={model}
                      className={({ active, selected }) =>
                        `px-3 py-3 cursor-pointer border-b border-gray-100 last:border-b-0 ${
                          active ? 'bg-gray-50' : ''
                        } ${selected ? 'bg-blue-50' : ''}`
                      }
                    >
                      {({ selected }) => (
                        <div className="relative flex items-start py-1">
                          <div className="flex-1 min-w-0 pr-3">
                            <div className="flex items-center justify-between">
                              <Tooltip content={model.id}>
                                <span className="text-sm font-medium truncate block max-w-[450px]">
                                  {model.id}
                                </span>
                              </Tooltip>
                              {selected && (
                                <Check className="w-4 h-4 text-blue-600 ml-2 shrink-0" />
                              )}
                            </div>
                            {/* The creation date is shown only alongside the stats. */}
                            <ModelStats
                              likes={model.likes}
                              downloads={model.downloads}
                              className="mt-1"
                            >
                              {model.createdAt && (
                                <span className="text-xs text-gray-400">
                                  {model.createdAt.split('T')[0]}
                                </span>
                              )}
                            </ModelStats>
                          </div>
                        </div>
                      )}
                    </ListboxOption>
                  ))}
                </div>
              )}
            </ListboxOptions>
          </Transition>
        </div>
      </Listbox>
    </div>
  )
}

export default ModelSelector