File size: 7,933 Bytes
d4e3b98
 
 
 
 
 
 
 
1dac694
d4e3b98
1dac694
d4e3b98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dac694
 
 
 
 
 
 
 
 
 
 
d4e3b98
 
 
 
 
 
 
 
 
1dac694
d4e3b98
 
1dac694
 
d4e3b98
 
1dac694
d4e3b98
 
 
 
 
 
 
1dac694
 
d4e3b98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
from snac import SNAC
import numpy as np
import torch
import asyncio
import threading
import queue
import os

# Kartoffel-specific constants (based on the reference implementation)
# Token ids at or above this value are treated as audio codes; the offset is
# subtracted from them in extract_kartoffel_tokens.
CODE_TOKEN_OFFSET = 128266
CODE_START_TOKEN_ID = 128257  # Token marking the start of the audio codes
# Token id that extract_kartoffel_tokens drops from the code stream.
CODE_REMOVE_TOKEN_ID = 128258

# The SNAC model is loaded once, at import time, and shared by every decode call.
print("DEBUG KARTOFFEL: Loading SNAC model...")
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()

# Device selection: the SNAC_DEVICE environment variable wins; otherwise use
# CUDA when torch reports it as available, else fall back to the CPU.
snac_device = os.environ.get("SNAC_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
model = model.to(snac_device)
# NOTE(review): only the exact string "cuda" switches the model to half
# precision; a value such as "cuda:0" keeps the default dtype - confirm that
# this is intended.
if snac_device == "cuda":
    model = model.half()
# eval() was already applied right after loading; presumably repeated here as
# a safeguard after the device/dtype change.
model.eval()
print(f"DEBUG KARTOFFEL: SNAC model loaded successfully on device: {snac_device}")

def redistribute_codes_kartoffel(code_list):
    """Decode a flat list of Kartoffel audio codes into an audio tensor.

    The flat list is consumed in frames of 7 codes. Inside a frame, position
    ``p`` carries an offset of ``p * 4096`` which is removed here, and the
    codes are distributed over three layers: position 0 goes to layer 1,
    positions 1 and 4 to layer 2, and positions 2, 3, 5 and 6 to layer 3.
    The three layers are then handed to ``model.decode``.

    Args:
        code_list: Sequence of integer codes from which ``CODE_TOKEN_OFFSET``
            has already been subtracted. Trailing codes that do not fill a
            complete frame of 7 are ignored.

    Returns:
        Whatever ``model.decode`` returns for the frames, or an empty
        ``(1, 0)`` float32 tensor on ``snac_device`` when there is no
        complete frame or when a code lies outside ``[0, 4096)`` after its
        positional offset has been removed.
    """
    # Per-position stride of the token layout; also used as the exclusive
    # upper bound of a valid codebook index (the "hubertsiuzdak/snac_24khz"
    # checkpoint is expected to have 4096 entries per codebook).
    codebook_size = 4096
    frame_len = 7
    # Index into `layers` for each of the 7 positions of a frame.
    layer_of_position = (0, 1, 2, 2, 1, 2, 2)

    def _empty():
        return torch.tensor([[]], device=snac_device, dtype=torch.float32)

    if not code_list:
        return _empty()

    num_groups = len(code_list) // frame_len
    if num_groups == 0:
        return _empty()

    layers = ([], [], [])
    for group in range(num_groups):
        base_idx = frame_len * group
        for position in range(frame_len):
            code = code_list[base_idx + position] - position * codebook_size
            # An out-of-range index would fail inside the embedding lookup
            # (on CUDA as a device-side assert), so reject the chunk instead.
            if not 0 <= code < codebook_size:
                print(
                    f"DEBUG KARTOFFEL: Code out of range at group {group}, "
                    f"position {position}. Skipping chunk."
                )
                return _empty()
            layers[layer_of_position[position]].append(code)

    codes = [
        torch.tensor(layer, device=snac_device).unsqueeze(0)
        for layer in layers
    ]

    with torch.no_grad():
        audio_hat = model.decode(codes)
    return audio_hat

def convert_to_audio_kartoffel(audio_tensor):
    """Turn an audio tensor into raw 16-bit PCM bytes.

    Args:
        audio_tensor: Tensor of audio samples, or ``None``.

    Returns:
        The samples scaled by 32767, clipped to the int16 range and
        serialized with ``tobytes()``; ``b''`` when the tensor is ``None``
        or has no elements.
    """
    if audio_tensor is None:
        return b''
    if audio_tensor.numel() == 0:
        return b''

    # Bring the samples to the CPU as float32 before handing them to NumPy.
    samples = audio_tensor.squeeze().cpu().to(torch.float32).numpy()
    scaled = samples * 32767
    pcm16 = np.clip(scaled, -32768, 32767).astype(np.int16)
    return pcm16.tobytes()

def extract_kartoffel_tokens(token_text, tokenizer):
    """Extract the audio codes that follow the first code-start token.

    Args:
        token_text: Either a string made only of digits and whitespace
            (parsed directly as whitespace-separated token ids) or any other
            value, which is passed to ``tokenizer.encode``.
        tokenizer: Object with an ``encode`` method; only used for the
            non-numeric fallback.

    Returns:
        The ids found after the first ``CODE_START_TOKEN_ID`` that are not
        ``CODE_REMOVE_TOKEN_ID`` and are at least ``CODE_TOKEN_OFFSET``, each
        reduced by ``CODE_TOKEN_OFFSET``. An empty list when no start token
        is present or when any exception occurs.
    """
    try:
        print(f"DEBUG KARTOFFEL: Received token_text: {token_text}")

        # New format: a string consisting solely of numeric ids and whitespace.
        looks_numeric = isinstance(token_text, str) and not any(
            not ch.isdigit() and not ch.isspace() for ch in token_text
        )
        if not looks_numeric:
            # Old format: let the tokenizer turn the text into ids.
            ids = tokenizer.encode(token_text)
            print(f"DEBUG KARTOFFEL: Encoded token_ids: {ids}")
        else:
            ids = list(map(int, token_text.split()))
            print(f"DEBUG KARTOFFEL: Parsed token_ids from string: {ids}")

        # Position of the first start marker, if any.
        marker_pos = next(
            (pos for pos, tid in enumerate(ids) if tid == CODE_START_TOKEN_ID),
            None,
        )
        if marker_pos is None:
            print(f"DEBUG KARTOFFEL: No start token found ({CODE_START_TOKEN_ID})")
            return []

        print(f"DEBUG KARTOFFEL: Found start token at index {marker_pos}")

        # Everything after the marker is a candidate audio token.
        tail = ids[marker_pos + 1:]
        print(f"DEBUG KARTOFFEL: Potential code tokens: {tail[:10]}...")

        kept = []
        for tid in tail:
            if tid == CODE_REMOVE_TOKEN_ID:
                continue
            if tid < CODE_TOKEN_OFFSET:
                continue
            kept.append(tid)

        print(f"DEBUG KARTOFFEL: Valid raw codes count: {len(kept)}")

        # Remove the vocabulary offset so the values become plain codes.
        return [tid - CODE_TOKEN_OFFSET for tid in kept]

    except Exception as e:
        print(f"DEBUG KARTOFFEL: Error extracting tokens: {e}")
        return []

async def tokens_decoder_kartoffel(token_gen, tokenizer):
    """Decode a stream of generated token text into PCM16 audio chunks.

    Every piece yielded by ``token_gen`` is appended to the accumulated text,
    the full text is re-parsed with ``extract_kartoffel_tokens``, and codes
    that have not been seen before are queued. Whenever 28 codes (4 frames
    of 7) are queued they are decoded and the resulting bytes are yielded.
    After the stream ends, the remaining complete frames of 7 are decoded;
    an incomplete trailing frame is dropped.

    Args:
        token_gen: Async iterable yielding text pieces (strings).
        tokenizer: Passed through to ``extract_kartoffel_tokens``.

    Yields:
        Non-empty ``bytes`` objects produced by ``convert_to_audio_kartoffel``.
    """
    buffer = []
    accumulated_text = ""
    # Number of entries of the extracted code list already copied into
    # `buffer`. This must advance when codes are buffered, not when they are
    # decoded; otherwise leftovers still waiting in the buffer would be
    # appended again on the next token.
    buffered_count = 0
    chunk_size = 28  # 4 frames of 7 codes

    print("DEBUG KARTOFFEL: Starting token decoding")

    async for token_text in token_gen:
        accumulated_text += token_text
        print(f"DEBUG KARTOFFEL: Accumulated text length: {len(accumulated_text)}")

        # Re-extract the audio codes from everything received so far.
        valid_codes = extract_kartoffel_tokens(accumulated_text, tokenizer)

        if len(valid_codes) <= buffered_count:
            continue

        new_codes = valid_codes[buffered_count:]
        buffered_count = len(valid_codes)
        buffer.extend(new_codes)
        print(f"DEBUG KARTOFFEL: Added {len(new_codes)} new codes. Buffer size: {len(buffer)}")

        # Decode as soon as enough codes for a full chunk are available.
        while len(buffer) >= chunk_size:
            codes_to_process = buffer[:chunk_size]
            del buffer[:chunk_size]

            print(f"DEBUG KARTOFFEL: Processing {len(codes_to_process)} codes")

            audio_tensor = redistribute_codes_kartoffel(codes_to_process)
            audio_bytes = convert_to_audio_kartoffel(audio_tensor)

            if audio_bytes:
                print(f"DEBUG KARTOFFEL: Generated {len(audio_bytes)} bytes of audio")
                yield audio_bytes
            else:
                print("DEBUG KARTOFFEL: No audio bytes generated")

    # Flush what is left, restricted to complete frames of 7 codes.
    if len(buffer) >= 7:
        final_count = (len(buffer) // 7) * 7
        final_codes = buffer[:final_count]

        print(f"DEBUG KARTOFFEL: Processing final {len(final_codes)} codes")

        audio_tensor = redistribute_codes_kartoffel(final_codes)
        audio_bytes = convert_to_audio_kartoffel(audio_tensor)

        if audio_bytes:
            print(f"DEBUG KARTOFFEL: Generated final {len(audio_bytes)} bytes of audio")
            yield audio_bytes

    print("DEBUG KARTOFFEL: Token decoding completed")

def tokens_decoder_kartoffel_sync(syn_token_gen, tokenizer):
    """Synchronous generator wrapper around ``tokens_decoder_kartoffel``.

    A worker thread runs the async decoder in its own event loop and pushes
    each audio chunk onto a queue; this generator yields the chunks as they
    arrive. Exceptions raised inside the worker are printed (with traceback)
    and end the stream; they are not re-raised to the caller.

    Args:
        syn_token_gen: Synchronous iterable yielding token text pieces.
        tokenizer: Passed through to ``tokens_decoder_kartoffel``.

    Yields:
        Audio chunks (``bytes``) in the order the decoder produced them.
    """
    audio_queue = queue.Queue()
    # Set when the consumer is done, so the worker stops pulling tokens
    # instead of decoding the rest of the stream for nobody.
    stop_requested = threading.Event()

    # Adapt the synchronous token iterable to the async interface.
    async def async_token_gen():
        for token in syn_token_gen:
            if stop_requested.is_set():
                break
            yield token

    async def async_producer():
        try:
            async for audio_chunk in tokens_decoder_kartoffel(async_token_gen(), tokenizer):
                audio_queue.put(audio_chunk)
        except Exception as e:
            print(f"DEBUG KARTOFFEL: Error in async producer: {e}")
            import traceback
            traceback.print_exc()
        finally:
            audio_queue.put(None)  # Sentinel: no more audio will follow

    def run_async():
        asyncio.run(async_producer())

    thread = threading.Thread(target=run_async)
    thread.start()

    try:
        while True:
            audio = audio_queue.get()
            if audio is None:
                break
            yield audio
    finally:
        # Also reached when the caller closes this generator early: tell the
        # worker to stop at the next token and wait for it to finish.
        stop_requested.set()
        thread.join()