import gradio as gr
from datasets import load_dataset, Dataset, Audio, concatenate_datasets
import json
import os
from datetime import datetime
import shutil

# Submitted recordings are copied into this directory; it is created on import.
AUDIO_DIR = "data/audios"
# Initial value of the "Sampling rate" field in the UI.
SAMPLING_RATE = 16000
os.makedirs(AUDIO_DIR, exist_ok=True)

# Module-level application state, read and mutated by the callbacks.
state = dict(
    sentences=[],  # list of {"id", "text"} dicts in display order
    recordings={},  # sentence id -> list of takes ({"id", "text", "audio"})
    index=0,  # position of the sentence currently displayed
    idx=0,  # counter used to mint "s_<n>" IDs for manually typed sentences
    json_loaded=False,  # becomes True once a sentences JSON is uploaded
)

def load_json(file):
    """Append sentences from an uploaded JSON file and refresh the display.

    The file must contain a list of ``{"id": ..., "text": ...}`` objects.
    Sentences whose ``id`` is already loaded are skipped, and recordings that
    already exist are never reset, so re-uploading the same file is harmless.

    Args:
        file: Value emitted by the ``gr.File`` upload. This is either a filepath
            string or a tempfile-like object exposing ``.name``, depending on
            the Gradio version. It is ``None`` when the upload is cleared.

    Returns:
        The UI update tuple produced by ``update_display``.
    """
    if file is None:
        # Guard against a cleared upload widget.
        return update_display()

    path = getattr(file, "name", file)
    with open(path, "r", encoding="utf-8") as f:
        content = json.load(f)

    if not isinstance(content, list):
        gr.Warning("The JSON file must contain a list of {\"id\", \"text\"} objects")
        return update_display()

    known_ids = {sentence["id"] for sentence in state["sentences"]}
    for sentence in content:
        if sentence["id"] in known_ids:
            continue
        known_ids.add(sentence["id"])
        state["sentences"].append(sentence)
        # setdefault (not assignment) keeps takes recorded before this upload.
        state["recordings"].setdefault(sentence["id"], [])
    state["json_loaded"] = True
    return update_display()

def update_display():
    """Build the UI update tuple for the current ``state``.

    The nine returned values map, in order, to the following components:
    sentence text, audio input (always cleared), ID label, progress label,
    record button, previous button, next button, recorded-audio player value
    and recorded-audio player visibility.

    In JSON mode, once the index has moved past the last sentence, the
    metadata is exported via ``export_json`` and the recording controls are
    hidden.
    """
    if not state["sentences"]:
        hidden = [gr.update(visible=False) for _ in range(5)]
        return ("No data loaded.", None, "", "", *hidden)

    position = state["index"]
    total = len(state["sentences"])
    progress = ""
    if state["json_loaded"]:
        if position >= total:
            export_json()
            hidden = [gr.update(visible=False) for _ in range(5)]
            return ("✅ All sentences recorded!\n💾 Data Exported to Json", None, "", "", *hidden)

        recorded = sum(1 for takes in state["recordings"].values() if takes)
        progress = f"{recorded} / {total} recorded"

    # Show Previous/Next only when go_previous/go_next can actually move.
    # In manual mode position may equal total (the "new sentence" slot).
    prev_btn_update = gr.update(visible=position > 0)
    next_btn_update = gr.update(visible=position < total - 1)

    text = ""
    # In the manual "new sentence" slot, show the ID the next take will get.
    current_id = f"s_{state['idx']}"
    takes = []
    if position < total:
        current = state["sentences"][position]
        current_id = current["id"]
        text = current["text"]
        takes = state["recordings"].get(current_id, [])

    if takes:
        # Play back the most recent take for this sentence.
        current_audio = takes[-1]["audio"]
        audio_visibility = gr.update(visible=True)
    else:
        current_audio = None
        audio_visibility = gr.update(visible=False)

    return (
        text,
        None,
        f"ID: {current_id}",
        progress,
        gr.update(visible=True),
        prev_btn_update,
        next_btn_update,
        current_audio,
        audio_visibility,
    )

def record_audio(audio, text):
    """Save the submitted take for the current sentence and advance.

    Args:
        audio: Filepath of the recorded clip, or ``None`` if nothing was
            recorded.
        text: Current contents of the sentence textbox. It overwrites the
            stored text of the current sentence. In manual mode, on the
            "new sentence" slot, it creates a new auto-numbered sentence.

    Returns:
        The UI update tuple produced by ``update_display``.
    """
    sentences = state["sentences"]
    # Only in JSON mode does an index past the end mean "all done". In manual
    # mode, index == len(sentences) is the slot for typing a new sentence.
    if state["json_loaded"] and state["index"] >= len(sentences):
        return update_display()

    if audio is None:
        gr.Warning("The audio is empty, please provide a valid audio")
        return update_display()

    if state["index"] < len(sentences):
        # Existing sentence (loaded from JSON, or a manual one revisited with
        # Previous): keep its ID and store the text as currently edited.
        sentences[state["index"]]["text"] = text
    else:
        # Manual mode, new-sentence slot.
        sentences.append({"id": f"s_{state['idx']}", "text": text})
        state["idx"] += 1

    sentence = sentences[state["index"]]
    uid = sentence["id"]

    # Microseconds prevent two takes in the same second from overwriting each other.
    filename = f"{uid}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav"
    filepath = os.path.join(AUDIO_DIR, filename)
    shutil.copy(audio, filepath)

    # The first take keeps the bare ID; later takes get a "_v<N>" suffix.
    previous_takes = state["recordings"].get(uid, [])
    uid_versioning = f"{uid}_v{len(previous_takes)}" if previous_takes else uid

    state["recordings"].setdefault(uid, []).append({
        "id": uid_versioning,
        "text": sentence["text"],
        "audio": filepath,
    })
    state["index"] += 1
    return update_display()

def export_json(output_path="data/tts_dataset.json"):
    """Write every recorded take to a JSON metadata file.

    The file holds a flat list with one ``{"id", "text", "audio"}`` dict per
    take, in the order of ``state["recordings"]``.

    Args:
        output_path: Destination file; parent directories are created if needed.

    Returns:
        ``output_path`` when the file was written. Returns ``None`` when there
        is nothing to export, so a ``gr.File`` output is never pointed at a
        missing or stale file.
    """
    data = [record for records in state["recordings"].values() for record in records]
    if not data:
        gr.Warning("There is no recorded data")
        return None

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-English sentence text human-readable.
        json.dump(data, f, indent=2, ensure_ascii=False)
    return output_path

def go_previous():
    """Move back one sentence (staying put on the first) and refresh the UI."""
    state["index"] -= 1 if state["index"] > 0 else 0
    return update_display()

def go_next():
    """Move forward one sentence unless already on the last, then refresh the UI."""
    last_position = len(state["sentences"]) - 1
    if state["index"] >= last_position:
        return update_display()
    state["index"] += 1
    return update_display()
def push_to_hub(hub_id, is_new_dataset, sampling_rate=None):
    """Upload every recorded take to a Hugging Face dataset repository.

    Args:
        hub_id: Target dataset repository ID; blank values are rejected.
        is_new_dataset: If falsy, the existing ``train`` split of ``hub_id`` is
            loaded and the new takes are appended to it before pushing.
        sampling_rate: Sampling rate for the ``Audio`` column cast. ``None``
            (or an omitted argument) falls back to ``SAMPLING_RATE``.

    Returns:
        The UI update tuple produced by ``update_display``.
    """
    hub_id = (hub_id or "").strip()
    if not hub_id:
        gr.Warning("The hub_id field is empty, please provide a relevant hub id.")
        return update_display()

    # One row per take (every version of every sentence).
    recordings = [
        {"id": take["id"], "audio": take["audio"], "text": take["text"]}
        for takes in state["recordings"].values()
        for take in takes
    ]
    if not recordings:
        # An empty Dataset has no "audio" column to cast.
        gr.Warning("There is no recorded data")
        return update_display()

    if sampling_rate is None:
        sampling_rate = SAMPLING_RATE

    dataset = Dataset.from_list(recordings)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=int(sampling_rate)))
    if not is_new_dataset:
        previous_dataset = load_dataset(hub_id, split="train")
        dataset = concatenate_datasets([previous_dataset, dataset])
    dataset.push_to_hub(hub_id)
    gr.Info("Successfully synced with the hub")
    return update_display()
with gr.Blocks() as demo:
    gr.Markdown("""# 🗣️ TTS Dataset Recorder

Welcome to the **TTS Dataset Recorder**! This tool helps you quickly create a high-quality dataset for Text-to-Speech (TTS) models. Whether you're starting from scratch or have a pre-existing set of text data, this app lets you record audio samples and export them with the corresponding metadata.

### **How to Use?**
1. **Upload a JSON File** containing the sentences you'd like to record (or type them directly into the app).
2. **Record Audio** for each sentence. The app automatically associates each recording with the correct text.
3. **Export the Dataset** as a JSON file or **Sync** it to Hugging Face for easy sharing and use.

### **Data Input Format**
Your JSON file should follow this structure:
```json
[
    { "id": "001", "text": "Hello, how are you?" },
    { "id": "002", "text": "This is a sample sentence." }
]
```
""")

    with gr.Row():
        json_file = gr.File(label="Upload Sentences JSON", file_types=[".json"])
        with gr.Column():
            export_btn = gr.Button("💾 Export Metadata")
            with gr.Row():
                hub_id = gr.Textbox(label="Hub id", interactive=True)
                with gr.Row():
                    is_new_dataset = gr.Checkbox(label="New dataset", interactive=True)
                    sampling_rate = gr.Number(label="Sampling rate", value=SAMPLING_RATE, precision=0)
            push_to_hub_btn = gr.Button("🤗 Sync to HuggingFace")

    id_display = gr.Textbox(label="ID", interactive=False)
    progress_text = gr.Textbox(label="Progress", interactive=False)
    sentence_text = gr.Textbox(label="Sentence", interactive=True)
    audio_input = gr.Audio(type="filepath", label="Record your voice", interactive=True)
    record_btn = gr.Button("✅ Submit Recording")

    with gr.Row():
        prev_btn = gr.Button("⬅️ Previous")
        next_btn = gr.Button("➡️ Next")

    audio_player = gr.Audio(label="Play Recorded Audio", type="filepath")
    export_file = gr.File()

    # Order must match the 9-tuple returned by update_display(). audio_player
    # is listed twice: one slot carries its value, the other its visibility.
    display_outputs = [
        sentence_text,
        audio_input,
        id_display,
        progress_text,
        record_btn,
        prev_btn,
        next_btn,
        audio_player,
        audio_player,
    ]

    json_file.change(load_json, inputs=json_file, outputs=display_outputs)
    record_btn.click(record_audio, inputs=[audio_input, sentence_text], outputs=display_outputs)
    export_btn.click(export_json, outputs=export_file)

    prev_btn.click(go_previous, outputs=display_outputs)
    next_btn.click(go_next, outputs=display_outputs)

    # push_to_hub takes (hub_id, is_new_dataset, sampling_rate).
    push_to_hub_btn.click(
        push_to_hub,
        inputs=[hub_id, is_new_dataset, sampling_rate],
        outputs=display_outputs,
    )

demo.launch()