Commit 79eafc9 · Parent(s): 25647ae
Vokturz committed

Add StyleTTS2 support with KokoroTTS integration

public/workers/text-to-speech.js CHANGED

@@ -1,5 +1,6 @@
 /* eslint-disable no-restricted-globals */
 import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@latest'
+import { KokoroTTS } from 'https://cdn.jsdelivr.net/npm/kokoro-js@1.2.1/dist/kokoro.web.js'
 
 class MyTextToSpeechPipeline {
   static task = 'text-to-speech'
@@ -40,10 +41,66 @@ class MyTextToSpeechPipeline {
   }
 }
 
-// Listen for messages from the main thread
+class MyKokoroTTSPipeline {
+  static instance = null
+
+  static async getInstance(model, dtype = 'fp32', progress_callback = null) {
+    try {
+      const device = 'webgpu'
+      if (progress_callback) {
+        progress_callback({
+          status: 'loading',
+          message: `Loading Kokoro TTS model with ${device} device`
+        })
+      }
+
+      this.instance = await KokoroTTS.from_pretrained(model, {
+        dtype,
+        device,
+        progress_callback: progress_callback
+          ? (data) => {
+              progress_callback({
+                status: 'loading',
+                ...data
+              })
+            }
+          : null
+      })
+      return this.instance
+    } catch (webgpuError) {
+      // Fallback to WASM if WebGPU fails
+      if (progress_callback) {
+        progress_callback({
+          status: 'fallback',
+          message: 'WebGPU failed, falling back to WASM'
+        })
+      }
+      try {
+        this.instance = await KokoroTTS.from_pretrained(model, {
+          dtype,
+          device: 'wasm',
+          progress_callback: progress_callback
+            ? (data) => {
+                progress_callback({
+                  status: 'loading',
+                  ...data
+                })
+              }
+            : null
+        })
+        return this.instance
+      } catch (wasmError) {
+        throw new Error(
+          `Both WebGPU and WASM failed for Kokoro TTS. WebGPU error: ${webgpuError.message}. WASM error: ${wasmError.message}`
+        )
+      }
+    }
+  }
+}
+
 self.addEventListener('message', async (event) => {
   try {
-    const { type, model, dtype, text, config } = event.data
+    const { type, model, dtype, text, isStyleTTS2, config } = event.data
 
     if (!model) {
       self.postMessage({
@@ -53,19 +110,31 @@ self.addEventListener('message', async (event) => {
       return
     }
 
-    // Retrieve the pipeline. This will download the model if not already cached.
-    const synthesizer = await MyTextToSpeechPipeline.getInstance(
-      model,
-      dtype,
-      (x) => {
-        self.postMessage({ status: 'loading', output: x })
-      }
-    )
+    let synthesizer
+    if (isStyleTTS2) {
+      // Use Kokoro TTS for StyleTTS2 models
+      synthesizer = await MyKokoroTTSPipeline.getInstance(
+        model,
+        dtype || 'q8',
+        (x) => {
+          self.postMessage({ status: 'loading', output: x })
+        }
+      )
+    } else {
+      // Use standard transformers pipeline
+      synthesizer = await MyTextToSpeechPipeline.getInstance(
+        model,
+        dtype || 'fp32',
+        (x) => {
+          self.postMessage({ status: 'loading', output: x })
+        }
+      )
+    }
 
     if (type === 'load') {
       self.postMessage({
         status: 'ready',
-        output: `Model ${model}, dtype ${dtype} loaded`
+        output: `Model ${model}${isStyleTTS2 ? ' StyleTTS2' : ''}, dtype ${dtype} loaded`
       })
       return
     }
@@ -79,31 +148,44 @@ self.addEventListener('message', async (event) => {
      return
    }
 
-    const options = {}
-
-    // Add speaker embeddings if provided
-    if (config?.speakerEmbeddings) {
-      try {
-        const response = await fetch(config.speakerEmbeddings)
-        if (response.ok) {
-          const embeddings = await response.arrayBuffer()
-          options.speaker_embeddings = new Float32Array(embeddings)
-        }
-      } catch (error) {
-        console.warn('Failed to load speaker embeddings:', error)
-        // Continue without speaker embeddings
-      }
-    }
-
     try {
-      const output = await synthesizer(text.trim(), options)
+      let output
+
+      if (isStyleTTS2) {
+        const options = {}
+
+        options.voice = config.voice
+        const audioResult = await synthesizer.generate(text.trim(), options)
+
+        output = {
+          audio: Array.from(audioResult.audio),
+          sampling_rate: audioResult.sampling_rate || 24000 // Default for Kokoro
+        }
+      } else {
+        const options = {}
+
+        if (config?.speakerEmbeddings) {
+          try {
+            const response = await fetch(config.speakerEmbeddings)
+            if (response.ok) {
+              const embeddings = await response.arrayBuffer()
+              options.speaker_embeddings = new Float32Array(embeddings)
+            }
+          } catch (error) {
+            console.warn('Failed to load speaker embeddings:', error)
+          }
+        }
+
+        const result = await synthesizer(text.trim(), options)
+        output = {
+          audio: Array.from(result.audio),
+          sampling_rate: result.sampling_rate
+        }
+      }
 
       self.postMessage({
        status: 'output',
-        output: {
-          audio: Array.from(output.audio),
-          sampling_rate: output.sampling_rate
-        }
+        output
      })
 
      self.postMessage({ status: 'ready' })
@@ -114,7 +196,8 @@ self.addEventListener('message', async (event) => {
   } catch (error) {
     self.postMessage({
       status: 'error',
-      output: error.message || 'An error occurred during text-to-speech synthesis'
+      output:
+        error.message || 'An error occurred during text-to-speech synthesis'
     })
   }
 })
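
For context, a minimal sketch of how the main thread might drive the updated worker. The worker URL and the model id are illustrative assumptions, not part of this commit; the message fields and status values match what the handler above destructures and posts.

// Sketch only: worker path and model id are placeholders.
const worker = new Worker('/workers/text-to-speech.js', { type: 'module' })

worker.addEventListener('message', (event) => {
  const { status, output } = event.data
  if (status === 'loading') {
    // Download/progress events; the WebGPU->WASM fallback notice also arrives
    // here, wrapped by the progress callback as output.status === 'fallback'.
    console.log(output)
  } else if (status === 'output') {
    // The worker serializes audio as a plain number[]; rebuild the Float32Array.
    const audio = new Float32Array(output.audio)
    console.log(`${audio.length} samples at ${output.sampling_rate} Hz`)
  } else if (status === 'error') {
    console.error(output)
  }
})

// Warm the model without synthesizing anything yet.
worker.postMessage({
  type: 'load',
  model: 'onnx-community/Kokoro-82M-ONNX', // placeholder model id
  dtype: 'q8',
  isStyleTTS2: true // routes to MyKokoroTTSPipeline instead of pipeline()
})

Note the design choice in the loader: a WebGPU initialization failure is downgraded silently to WASM, and only the combined failure of both backends surfaces as a worker error.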
src/components/AudioPlayer.tsx CHANGED

@@ -6,6 +6,7 @@ interface AudioPlayerProps {
   samplingRate: number
   text: string
   index: number
+  voice?: string
 }
 
 function createWavBuffer(
@@ -235,7 +236,13 @@ function CustomAudioVisualizer({
   )
 }
 
-function AudioPlayer({ audio, samplingRate, text, index }: AudioPlayerProps) {
+function AudioPlayer({
+  audio,
+  samplingRate,
+  text,
+  index,
+  voice
+}: AudioPlayerProps) {
   const [isPlaying, setIsPlaying] = useState(false)
   const [currentTime, setCurrentTime] = useState(0)
   const [duration, setDuration] = useState(0)
@@ -373,7 +380,9 @@ function AudioPlayer({ audio, samplingRate, text, index }: AudioPlayerProps) {
   return (
     <div className="border border-gray-200 rounded-lg p-4 bg-gray-50">
       <div className="mb-3">
-        <p className="text-sm text-gray-700 font-medium mb-2">Prompt:</p>
+        <p className="text-sm text-gray-700 font-medium mb-2">
+          Prompt{voice ? ` (${voice})` : ''}:
+        </p>
         <p className="text-sm text-gray-600 italic bg-white p-2 rounded border">
           "{text}"
         </p>
src/components/ModelLoader.tsx CHANGED

@@ -8,7 +8,6 @@ import { Alert, AlertDescription } from './ui/alert'
 const ModelLoader = () => {
   const [showAlert, setShowAlert] = useState(false)
   const [alertMessage, setAlertMessage] = useState<React.ReactNode>('')
-  const [lastModel, setLastModel] = useState<string | null>(null)
   const {
     modelInfo,
     selectedQuantization,
@@ -134,7 +133,9 @@ const ModelLoader = () => {
     const message = {
       type: 'load',
       model: modelInfo.name,
-      dtype: selectedQuantization ?? 'fp32'
+      dtype: selectedQuantization ?? 'fp32',
+      isStyleTTS2:
+        modelInfo.isStyleTTS2 || modelInfo.name.includes('kitten-tts') || false // text-to-speech only
     }
     activeWorker?.postMessage(message)
   }, [modelInfo, selectedQuantization, activeWorker])
@@ -149,7 +150,7 @@ const ModelLoader = () => {
 
       <div className="flex items-center justify-between space-x-4">
         <div className="flex items-center space-x-2">
-          {modelInfo.supportedQuantizations.length > 1 ? (
+          {modelInfo.supportedQuantizations.length >= 1 ? (
            <>
              <span className="text-xs text-gray-600 font-medium">Quant:</span>
 
src/components/ModelSelector.tsx CHANGED

@@ -20,6 +20,7 @@ import {
   X
 } from 'lucide-react'
 import Tooltip from './Tooltip'
+import { ModelInfoResponse } from '@/types'
 
 type SortOption = 'likes' | 'downloads' | 'createdAt' | 'name'
 
@@ -80,9 +81,9 @@ function ModelSelector() {
 
   // Function to fetch detailed model info and set as selected
   const fetchAndSetModelInfo = useCallback(
-    async (modelId: string, isCustom: boolean = false) => {
+    async (model: ModelInfoResponse, isCustom: boolean = false) => {
       try {
-        const modelInfoResponse = await getModelInfo(modelId, pipeline)
+        const modelInfoResponse = await getModelInfo(model.id, pipeline)
 
         let parameters = 0
         if (modelInfoResponse.safetensors) {
@@ -95,9 +96,11 @@ function ModelSelector() {
            0
        }
 
+        const allTags = [...model.tags, ...modelInfoResponse.tags]
+
         const modelInfo = {
-          id: modelId,
-          name: modelInfoResponse.id || modelId,
+          id: model.id,
+          name: modelInfoResponse.id || model.id,
           architecture:
             modelInfoResponse.config?.architectures?.[0] || 'Unknown',
           parameters,
@@ -112,7 +115,9 @@ function ModelSelector() {
           hasChatTemplate: Boolean(
             modelInfoResponse.config?.tokenizer_config?.chat_template
           ),
-          widgetData: modelInfoResponse.widgetData
+          isStyleTTS2: Boolean(allTags.includes('style_text_to_speech_2')),
+          widgetData: modelInfoResponse.widgetData,
+          voices: modelInfoResponse.voices
         }
         setModelInfo(modelInfo)
         setIsCustomModel(isCustom)
@@ -143,12 +148,12 @@ function ModelSelector() {
   useEffect(() => {
     if (models.length > 0 && !isCustomModel && !modelInfo) {
       const firstModel = sortedModels[0]
-      fetchAndSetModelInfo(firstModel.id, false)
+      fetchAndSetModelInfo(firstModel, false)
     }
   }, [models, sortedModels, fetchAndSetModelInfo, isCustomModel, modelInfo])
 
-  const handleModelSelect = (modelId: string) => {
-    fetchAndSetModelInfo(modelId, false)
+  const handleModelSelect = (model: ModelInfoResponse) => {
+    fetchAndSetModelInfo(model, false)
   }
 
   const handleSortChange = (newSortBy: SortOption) => {
@@ -170,7 +175,13 @@ function ModelSelector() {
     setCustomModelError('')
 
     try {
-      await fetchAndSetModelInfo(customModelName.trim(), true)
+      await fetchAndSetModelInfo(
+        {
+          id: customModelName.trim(),
+          tags: []
+        } as unknown as ModelInfoResponse,
+        true
+      )
       setShowCustomInput(false)
       setCustomModelName('')
     } catch (error) {
@@ -186,7 +197,7 @@ function ModelSelector() {
     setIsCustomModel(false)
     // Load the first model from the list
     if (sortedModels.length > 0) {
-      fetchAndSetModelInfo(sortedModels[0].id, false)
+      fetchAndSetModelInfo(sortedModels[0], false)
     }
   }
 
@@ -281,7 +292,7 @@ function ModelSelector() {
           <div className="relative">
             <Listbox
               value={selectedModel}
-              onChange={(model) => handleModelSelect(model.id)}
+              onChange={(model) => handleModelSelect(model)}
             >
               <div className="relative">
                 <ListboxButton className="w-full px-3 py-2 border border-gray-300 rounded-md focus:outline-hidden focus:ring-2 focus:ring-blue-500 focus:border-transparent bg-white text-left flex items-center justify-between">
src/components/Sidebar.tsx CHANGED

@@ -79,17 +79,6 @@ const Sidebar = ({
         </Tooltip>
       </span>
     )}
-    {pipeline === 'text-to-speech' && (
-      <span className="flex text-xs text-yellow-500 justify-center text-center">
-        Not fully supported{' '}
-        <Tooltip
-          content="Transformers.js has limited support for text-to-speech"
-          className="transform -translate-x-1/3 break-keep max-w-12"
-        >
-          <CircleQuestionMark className="inline w-4 h-4 ml-1" />
-        </Tooltip>
-      </span>
-    )}
   </div>
   <PipelineSelector pipeline={pipeline} setPipeline={setPipeline} />
 </div>
src/components/pipelines/TextToSpeech.tsx CHANGED

@@ -7,6 +7,7 @@ import {
   AudioResult
 } from '../../contexts/TextToSpeechContext'
 import AudioPlayer from '../AudioPlayer'
+import { preview } from 'vite'
 
 const SAMPLE_TEXTS = [
   'Hello, this is a sample text for text-to-speech synthesis.',
@@ -18,6 +19,7 @@ const SAMPLE_TEXTS = [
 function TextToSpeech() {
   const {
     config,
+    setConfig,
     audioResults,
     currentText,
     setCurrentText,
@@ -46,8 +48,10 @@ function TextToSpeech() {
       text: currentText.trim(),
       model: modelInfo.id,
       dtype: selectedQuantization ?? 'fp32',
+      isStyleTTS2: modelInfo.isStyleTTS2 ?? false,
       config: {
-        speakerEmbeddings: config.speakerEmbeddings
+        speakerEmbeddings: config.speakerEmbeddings,
+        voice: config.voice
       }
     }
 
@@ -72,7 +76,7 @@ function TextToSpeech() {
           audio: new Float32Array(output.audio),
           sampling_rate: output.sampling_rate
         }
-        addAudioResult(currentText, audioResult)
+        addAudioResult(currentText, audioResult, config.voice)
       } else if (status === 'ready' || status === 'error') {
         setIsSynthesizing(false)
       }
@@ -82,6 +86,15 @@ function TextToSpeech() {
     return () => activeWorker.removeEventListener('message', onMessageReceived)
   }, [activeWorker, currentText, addAudioResult])
 
+  useEffect(() => {
+    if (!modelInfo) return
+    if (modelInfo && modelInfo?.voices.length > 0)
+      setConfig((prev) => ({
+        ...prev,
+        voice: modelInfo.voices[0]
+      }))
+  }, [modelInfo])
+
   const handleKeyPress = (e: React.KeyboardEvent) => {
     if (e.key === 'Enter' && !e.shiftKey) {
       e.preventDefault()
@@ -172,6 +185,7 @@ function TextToSpeech() {
             samplingRate={result.sampling_rate}
             text={result.text}
             index={index}
+            voice={result.voice}
           />
         ))}
       </div>
src/components/pipelines/TextToSpeechConfig.tsx CHANGED

@@ -2,6 +2,14 @@ import React from 'react'
 import { Label } from '@/components/ui/label'
 import { Input } from '@/components/ui/input'
 import { useTextToSpeech } from '../../contexts/TextToSpeechContext'
+import { useModel } from '@/contexts/ModelContext'
+import {
+  Select,
+  SelectContent,
+  SelectItem,
+  SelectTrigger,
+  SelectValue
+} from '../ui/select'
 
 interface TextToSpeechConfigProps {
   className?: string
@@ -10,31 +18,64 @@ interface TextToSpeechConfigProps {
 const TextToSpeechConfig: React.FC<TextToSpeechConfigProps> = ({
   className = ''
 }) => {
+  const { modelInfo } = useModel()
   const { config, setConfig } = useTextToSpeech()
 
   return (
     <div className={`space-y-4 ${className}`}>
-      <div className="space-y-2">
-        <Label htmlFor="speakerEmbeddings" className="text-sm font-medium">
-          Speaker Embeddings URL
-        </Label>
-        <Input
-          id="speakerEmbeddings"
-          type="url"
-          value={config.speakerEmbeddings}
-          onChange={(e) =>
-            setConfig((prev) => ({
-              ...prev,
-              speakerEmbeddings: e.target.value
-            }))
-          }
-          placeholder="https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin"
-          className="text-sm"
-        />
-        <p className="text-xs text-gray-500">
-          URL to speaker embeddings file for voice characteristics
-        </p>
-      </div>
+      {modelInfo?.isStyleTTS2 ? (
+        <div className="space-y-2 h-1/3">
+          <p className="text-xs text-gray-500">Style TTS2 Model</p>
+          <Label htmlFor="speakerEmbeddings" className="text-sm font-medium">
+            Select Voice
+          </Label>
+          <Select
+            value={config.voice}
+            onValueChange={(value) =>
+              setConfig((prev) => ({
+                ...prev,
+                voice: value
+              }))
+            }
+          >
+            <SelectTrigger className="w-full text-sm xl:text-base">
+              <SelectValue placeholder="Select a voice" />
+            </SelectTrigger>
+            <SelectContent className="max-h-96">
+              {modelInfo.voices.map((voice) => (
+                <SelectItem key={voice} value={voice} className="text-sm">
+                  {voice}
+                </SelectItem>
+              ))}
+            </SelectContent>
+          </Select>
+          <p className="text-xs text-gray-500">
+            Voice to use for text-to-speech synthesis.
+          </p>
+        </div>
+      ) : (
+        <div className="space-y-2">
+          <Label htmlFor="speakerEmbeddings" className="text-sm font-medium">
+            Speaker Embeddings URL
+          </Label>
+          <Input
+            id="speakerEmbeddings"
+            type="url"
+            value={config.speakerEmbeddings}
+            onChange={(e) =>
+              setConfig((prev) => ({
+                ...prev,
+                speakerEmbeddings: e.target.value
+              }))
+            }
+            placeholder="https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin"
+            className="text-sm"
+          />
+          <p className="text-xs text-gray-500">
+            URL to speaker embeddings file for voice characteristics
+          </p>
+        </div>
+      )}
     </div>
   )
 }
src/contexts/TextToSpeechContext.tsx CHANGED

@@ -1,13 +1,15 @@
 import { createContext, useContext, useState, ReactNode } from 'react'
 
 export interface TextToSpeechConfigState {
-  speakerEmbeddings: string
+  speakerEmbeddings?: string
+  voice?: string
 }
 
 export interface AudioResult {
   audio: Float32Array
   sampling_rate: number
   text: string
+  voice?: string
 }
 
 interface TextToSpeechContextType {
@@ -28,14 +30,19 @@ const TextToSpeechContext = createContext<TextToSpeechContextType | undefined>(
 export function TextToSpeechProvider({ children }: { children: ReactNode }) {
   const [config, setConfig] = useState<TextToSpeechConfigState>({
     speakerEmbeddings:
-      'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin'
+      'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin',
+    voice: undefined
   })
 
   const [audioResults, setAudioResults] = useState<AudioResult[]>([])
   const [currentText, setCurrentText] = useState<string>('')
 
-  const addAudioResult = (text: string, audio: Omit<AudioResult, 'text'>) => {
-    const fullAudioResult: AudioResult = { ...audio, text }
+  const addAudioResult = (
+    text: string,
+    audio: Omit<AudioResult, 'text'>,
+    voice?: string
+  ) => {
+    const fullAudioResult: AudioResult = { ...audio, text, voice }
     setAudioResults((prev) => [...prev, fullAudioResult])
   }
 
src/lib/huggingface.ts CHANGED

@@ -75,6 +75,20 @@ const getModelInfo = async (
     return getNumericValue(a) - getNumericValue(b)
   })
 
+  if (
+    uniqueSupportedQuantizations.length === 0 &&
+    siblingFiles.some((file) => file.endsWith('_quantized.onnx'))
+  ) {
+    uniqueSupportedQuantizations.push('q8')
+  }
+
+  const voices: string[] = []
+  siblingFiles
+    .filter((file) => file.startsWith('voices/') && !file.endsWith('af.bin'))
+    .forEach((file) => {
+      voices.push(file.split('/')[1].split('.')[0])
+    })
+
   // Fetch README content
   const fetchReadme = async (modelId: string): Promise<string> => {
     try {
@@ -111,7 +125,8 @@ const getModelInfo = async (
       incompatibilityReason,
       supportedQuantizations:
         uniqueSupportedQuantizations as QuantizationType[],
-      readme
+      readme,
+      voices
     }
   }
 }
@@ -123,7 +138,8 @@ const getModelInfo = async (
     isCompatible,
     incompatibilityReason,
     supportedQuantizations: uniqueSupportedQuantizations as QuantizationType[],
-    readme
+    readme,
+    voices
  }
}
 
@@ -180,7 +196,7 @@ const getModelsByPipeline = async (
   return uniqueModels
     .filter(
       (model: ModelInfoResponse) =>
-        !model.tags.includes('style_text_to_speech_2') &&
+        // !model.tags.includes('style_text_to_speech_2') &&
         !model.id.includes('qwen2')
     )
     .slice(0, 30)
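
The voice list is inferred purely from repository file names: every voices/<name>.bin sibling except af.bin becomes a selectable voice, and a lone *_quantized.onnx file now advertises q8 support. A small illustration of that mapping, with made-up file names:

// Made-up sibling files, to illustrate the voices/ and q8 logic above.
const siblingFiles = [
  'onnx/model_quantized.onnx', // would trigger the q8 fallback
  'voices/af_bella.bin',
  'voices/am_adam.bin',
  'voices/af.bin' // dropped by the !file.endsWith('af.bin') filter
]

const voices = siblingFiles
  .filter((file) => file.startsWith('voices/') && !file.endsWith('af.bin'))
  .map((file) => file.split('/')[1].split('.')[0])

console.log(voices) // ['af_bella', 'am_adam']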
src/types.ts CHANGED

@@ -85,8 +85,10 @@ export interface TextToSpeechWorkerInput {
   text: string
   model: string
   dtype: QuantizationType
+  isStyleTTS2: boolean
   config?: {
     speakerEmbeddings?: string
+    voice?: string
   }
 }
 
@@ -157,8 +159,10 @@ export interface ModelInfo {
   supportedQuantizations: QuantizationType[]
   baseId?: string
   readme?: string
-  hasChatTemplate: boolean
+  hasChatTemplate: boolean // text-generation only
+  isStyleTTS2: boolean // text-to-speech only
   widgetData?: any
+  voices: string[] // text-to-speech only
 }
 
 export interface ModelInfoResponse {
@@ -202,4 +206,5 @@ export interface ModelInfoResponse {
   likes: number
   downloads: number
   readme?: string
+  voices: string[] // text-to-speech only
 }
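
Putting the type changes together, worker payloads for the two synthesis paths might look like this; the model ids and voice name are placeholders, and any TextToSpeechWorkerInput fields outside the hunks above (such as a type discriminator) are omitted:

// Illustrative inputs matching the fields shown above; ids are placeholders.
const styleTTS2Input = {
  text: 'Hello there.',
  model: 'onnx-community/Kokoro-82M-ONNX',
  dtype: 'q8',
  isStyleTTS2: true,
  config: { voice: 'af_bella' } // discovered from the repo's voices/ folder
}

const speechT5Input = {
  text: 'Hello there.',
  model: 'Xenova/speecht5_tts',
  dtype: 'fp32',
  isStyleTTS2: false,
  config: {
    speakerEmbeddings:
      'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin'
  }
}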