{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8Aa-nRCzPVdF"
      },
      "source": [
        "# IndicTrans2 HF Inference\n",
        "\n",
        "We provide an example notebook on how to use our IndicTrans2 models which were originally trained with the fairseq to HuggingFace transformers for inference purpose.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cfsv02IeP2It"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Please run the cells below to install the necessary dependencies.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qKcYlUZYGLrt"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!git clone https://github.com/AI4Bharat/IndicTrans2.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U3vs7FkIGSxK"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "%cd /content/IndicTrans2/huggingface_interface"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ddkRAXQ2Git0"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!python3 -m pip install nltk sacremoses pandas regex mock transformers>=4.33.2 mosestokenizer\n",
        "!python3 -c \"import nltk; nltk.download('punkt')\"\n",
        "!python3 -m pip install bitsandbytes scipy accelerate datasets\n",
        "!python3 -m pip install sentencepiece\n",
        "\n",
        "!git clone https://github.com/VarunGumma/IndicTransToolkit.git\n",
        "%cd IndicTransToolkit\n",
        "!python3 -m pip install --editable ./\n",
        "%cd .."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hjN7ub1tO33H"
      },
      "source": [
        "**IMPORTANT : Restart your run-time first and then run the cells below.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_SLBIw6rQB-0"
      },
      "source": [
        "## Inference\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fYczM2U6G1Zv"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer\n",
        "from IndicTransToolkit import IndicProcessor\n",
        "\n",
        "BATCH_SIZE = 4\n",
        "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "quantization = None"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xj1WCNjuHG-d"
      },
      "outputs": [],
      "source": [
        "def initialize_model_and_tokenizer(ckpt_dir, quantization):\n",
        "    if quantization == \"4-bit\":\n",
        "        qconfig = BitsAndBytesConfig(\n",
        "            load_in_4bit=True,\n",
        "            bnb_4bit_use_double_quant=True,\n",
        "            bnb_4bit_compute_dtype=torch.bfloat16,\n",
        "        )\n",
        "    elif quantization == \"8-bit\":\n",
        "        qconfig = BitsAndBytesConfig(\n",
        "            load_in_8bit=True,\n",
        "            bnb_8bit_use_double_quant=True,\n",
        "            bnb_8bit_compute_dtype=torch.bfloat16,\n",
        "        )\n",
        "    else:\n",
        "        qconfig = None\n",
        "\n",
        "    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)\n",
        "    model = AutoModelForSeq2SeqLM.from_pretrained(\n",
        "        ckpt_dir,\n",
        "        trust_remote_code=True,\n",
        "        low_cpu_mem_usage=True,\n",
        "        quantization_config=qconfig,\n",
        "    )\n",
        "\n",
        "    if qconfig == None:\n",
        "        model = model.to(DEVICE)\n",
        "        if DEVICE == \"cuda\":\n",
        "            model.half()\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    return tokenizer, model\n",
        "\n",
        "\n",
        "def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):\n",
        "    translations = []\n",
        "    for i in range(0, len(input_sentences), BATCH_SIZE):\n",
        "        batch = input_sentences[i : i + BATCH_SIZE]\n",
        "\n",
        "        # Preprocess the batch and extract entity mappings\n",
        "        batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)\n",
        "\n",
        "        # Tokenize the batch and generate input encodings\n",
        "        inputs = tokenizer(\n",
        "            batch,\n",
        "            truncation=True,\n",
        "            padding=\"longest\",\n",
        "            return_tensors=\"pt\",\n",
        "            return_attention_mask=True,\n",
        "        ).to(DEVICE)\n",
        "\n",
        "        # Generate translations using the model\n",
        "        with torch.no_grad():\n",
        "            generated_tokens = model.generate(\n",
        "                **inputs,\n",
        "                use_cache=True,\n",
        "                min_length=0,\n",
        "                max_length=256,\n",
        "                num_beams=5,\n",
        "                num_return_sequences=1,\n",
        "            )\n",
        "\n",
        "        # Decode the generated tokens into text\n",
        "\n",
        "        with tokenizer.as_target_tokenizer():\n",
        "            generated_tokens = tokenizer.batch_decode(\n",
        "                generated_tokens.detach().cpu().tolist(),\n",
        "                skip_special_tokens=True,\n",
        "                clean_up_tokenization_spaces=True,\n",
        "            )\n",
        "\n",
        "        # Postprocess the translations, including entity replacement\n",
        "        translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)\n",
        "\n",
        "        del inputs\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "    return translations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "erNCuZTEMt49"
      },
      "source": [
        "### English to Indic Example\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6OG3Bw-sHnf3",
        "outputId": "a204f50e-9456-4fb1-900a-e60680b97b99"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "eng_Latn - hin_Deva\n",
            "eng_Latn: When I was young, I used to go to the park every day.\n",
            "hin_Deva: जब मैं छोटा था, मैं हर दिन पार्क जाता था।\n",
            "eng_Latn: He has many old books, which he inherited from his ancestors.\n",
            "hin_Deva: उनके पास कई पुरानी किताबें हैं, जो उन्हें अपने पूर्वजों से विरासत में मिली हैं।\n",
            "eng_Latn: I can't figure out how to solve my problem.\n",
            "hin_Deva: मुझे समझ नहीं आ रहा है कि मैं अपनी समस्या का समाधान कैसे करूं।\n",
            "eng_Latn: She is very hardworking and intelligent, which is why she got all the good marks.\n",
            "hin_Deva: वह बहुत मेहनती और बुद्धिमान है, यही कारण है कि उसे सभी अच्छे अंक मिले।\n",
            "eng_Latn: We watched a new movie last week, which was very inspiring.\n",
            "hin_Deva: हमने पिछले हफ्ते एक नई फिल्म देखी, जो बहुत प्रेरणादायक थी।\n",
            "eng_Latn: If you had met me at that time, we would have gone out to eat.\n",
            "hin_Deva: अगर आप उस समय मुझसे मिलते तो हम बाहर खाना खाने जाते।\n",
            "eng_Latn: She went to the market with her sister to buy a new sari.\n",
            "hin_Deva: वह अपनी बहन के साथ नई साड़ी खरीदने के लिए बाजार गई थी।\n",
            "eng_Latn: Raj told me that he is going to his grandmother's house next month.\n",
            "hin_Deva: राज ने मुझे बताया कि वह अगले महीने अपनी दादी के घर जा रहा है।\n",
            "eng_Latn: All the kids were having fun at the party and were eating lots of sweets.\n",
            "hin_Deva: पार्टी में सभी बच्चे खूब मस्ती कर रहे थे और खूब मिठाइयां खा रहे थे।\n",
            "eng_Latn: My friend has invited me to his birthday party, and I will give him a gift.\n",
            "hin_Deva: मेरे दोस्त ने मुझे अपने जन्मदिन की पार्टी में आमंत्रित किया है, और मैं उसे एक उपहार दूंगा।\n"
          ]
        }
      ],
      "source": [
        "en_indic_ckpt_dir = \"ai4bharat/indictrans2-en-indic-1B\"  # ai4bharat/indictrans2-en-indic-dist-200M\n",
        "en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, quantization)\n",
        "\n",
        "ip = IndicProcessor(inference=True)\n",
        "\n",
        "en_sents = [\n",
        "    \"When I was young, I used to go to the park every day.\",\n",
        "    \"He has many old books, which he inherited from his ancestors.\",\n",
        "    \"I can't figure out how to solve my problem.\",\n",
        "    \"She is very hardworking and intelligent, which is why she got all the good marks.\",\n",
        "    \"We watched a new movie last week, which was very inspiring.\",\n",
        "    \"If you had met me at that time, we would have gone out to eat.\",\n",
        "    \"She went to the market with her sister to buy a new sari.\",\n",
        "    \"Raj told me that he is going to his grandmother's house next month.\",\n",
        "    \"All the kids were having fun at the party and were eating lots of sweets.\",\n",
        "    \"My friend has invited me to his birthday party, and I will give him a gift.\",\n",
        "]\n",
        "\n",
        "src_lang, tgt_lang = \"eng_Latn\", \"hin_Deva\"\n",
        "hi_translations = batch_translate(en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip)\n",
        "\n",
        "print(f\"\\n{src_lang} - {tgt_lang}\")\n",
        "for input_sentence, translation in zip(en_sents, hi_translations):\n",
        "    print(f\"{src_lang}: {input_sentence}\")\n",
        "    print(f\"{tgt_lang}: {translation}\")\n",
        "\n",
        "# flush the models to free the GPU memory\n",
        "del en_indic_tokenizer, en_indic_model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OM_1pbPtMpV9"
      },
      "source": [
        "### Indic to English Example"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PLCEWJKvGG9I",
        "outputId": "ab9d8726-67c7-490b-ecb3-208df1c0f741"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "hin_Deva - eng_Latn\n",
            "hin_Deva: जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\n",
            "eng_Latn: When I was young, I used to go to the park every day.\n",
            "hin_Deva: उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\n",
            "eng_Latn: She has a lot of old books, which she inherited from her grandparents.\n",
            "hin_Deva: मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\n",
            "eng_Latn: I don't know how to find a solution to my problem.\n",
            "hin_Deva: वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\n",
            "eng_Latn: He is very hardworking and understanding, so he got all the good marks.\n",
            "hin_Deva: हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\n",
            "eng_Latn: We saw a new movie last week that was very inspiring.\n",
            "hin_Deva: अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\n",
            "eng_Latn: If you'd given me a pass at that time, we'd have gone out to eat.\n",
            "hin_Deva: वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\n",
            "eng_Latn: She had gone to the market with her sister so that she could buy a new sari.\n",
            "hin_Deva: राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\n",
            "eng_Latn: Raj told me that he was going to his grandmother's house next month.\n",
            "hin_Deva: सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\n",
            "eng_Latn: All the children were having fun at the party and eating a lot of sweets.\n",
            "hin_Deva: मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\n",
            "eng_Latn: My friend has invited me to her birthday party, and I'll give her a present.\n"
          ]
        }
      ],
      "source": [
        "indic_en_ckpt_dir = \"ai4bharat/indictrans2-indic-en-1B\"  # ai4bharat/indictrans2-indic-en-dist-200M\n",
        "indic_en_tokenizer, indic_en_model = initialize_model_and_tokenizer(indic_en_ckpt_dir, quantization)\n",
        "\n",
        "ip = IndicProcessor(inference=True)\n",
        "\n",
        "hi_sents = [\n",
        "    \"जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\",\n",
        "    \"उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\",\n",
        "    \"मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\",\n",
        "    \"वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\",\n",
        "    \"हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\",\n",
        "    \"अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\",\n",
        "    \"वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\",\n",
        "    \"राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\",\n",
        "    \"सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\",\n",
        "    \"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\",\n",
        "]\n",
        "src_lang, tgt_lang = \"hin_Deva\", \"eng_Latn\"\n",
        "en_translations = batch_translate(hi_sents, src_lang, tgt_lang, indic_en_model, indic_en_tokenizer, ip)\n",
        "\n",
        "\n",
        "print(f\"\\n{src_lang} - {tgt_lang}\")\n",
        "for input_sentence, translation in zip(hi_sents, en_translations):\n",
        "    print(f\"{src_lang}: {input_sentence}\")\n",
        "    print(f\"{tgt_lang}: {translation}\")\n",
        "\n",
        "# flush the models to free the GPU memory\n",
        "del indic_en_tokenizer, indic_en_model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7VCAkyKBGtnV"
      },
      "source": [
        "### Indic to Indic Example\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_7TxTTCoKjti",
        "outputId": "df1a750b-0f32-478d-cfc9-e445f669f3ee"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "hin_Deva - mar_Deva\n",
            "hin_Deva: जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\n",
            "mar_Deva: मी लहान होतो तेव्हा मी दररोज उद्यानाला जायचे.\n",
            "hin_Deva: उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\n",
            "mar_Deva: तिच्याकडे बरीच जुनी पुस्तके आहेत, जी तिला तिच्या आजोबांकडून वारशाने मिळाली आहेत.\n",
            "hin_Deva: मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\n",
            "mar_Deva: माझ्या समस्येवर तोडगा कसा काढायचा हे मला समजत नाही.\n",
            "hin_Deva: वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\n",
            "mar_Deva: तो खूप मेहनती आणि बुद्धिमान आहे, त्यामुळे त्याला सर्व चांगले गुण मिळाले.\n",
            "hin_Deva: हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\n",
            "mar_Deva: आम्ही गेल्या आठवड्यात एक नवीन चित्रपट पाहिला जो खूप प्रेरणादायी होता.\n",
            "hin_Deva: अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\n",
            "mar_Deva: जर तुम्हाला त्या वेळी मला पास मिळाला तर आम्ही बाहेर जेवायला जाऊ.\n",
            "hin_Deva: वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\n",
            "mar_Deva: ती तिच्या बहिणीसोबत बाजारात गेली होती जेणेकरून ती नवीन साडी खरेदी करू शकेल.\n",
            "hin_Deva: राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\n",
            "mar_Deva: राजने मला सांगितले की तो पुढच्या महिन्यात त्याच्या आजीच्या घरी जात आहे.\n",
            "hin_Deva: सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\n",
            "mar_Deva: सर्व मुले पार्टीचा आनंद घेत होती आणि भरपूर मिठाई खात होती.\n",
            "hin_Deva: मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\n",
            "mar_Deva: माझ्या मित्राने मला त्याच्या वाढदिवसाच्या मेजवानीसाठी आमंत्रित केले आहे आणि मी त्याला भेटवस्तू देईन.\n"
          ]
        }
      ],
      "source": [
        "indic_indic_ckpt_dir = \"ai4bharat/indictrans2-indic-indic-1B\"  # ai4bharat/indictrans2-indic-indic-dist-320M\n",
        "indic_indic_tokenizer, indic_indic_model = initialize_model_and_tokenizer(indic_indic_ckpt_dir, quantization)\n",
        "\n",
        "ip = IndicProcessor(inference=True)\n",
        "\n",
        "hi_sents = [\n",
        "    \"जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\",\n",
        "    \"उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\",\n",
        "    \"मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\",\n",
        "    \"वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\",\n",
        "    \"हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\",\n",
        "    \"अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\",\n",
        "    \"वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\",\n",
        "    \"राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\",\n",
        "    \"सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\",\n",
        "    \"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\",\n",
        "]\n",
        "src_lang, tgt_lang = \"hin_Deva\", \"mar_Deva\"\n",
        "mr_translations = batch_translate(hi_sents, src_lang, tgt_lang, indic_indic_model, indic_indic_tokenizer, ip)\n",
        "\n",
        "print(f\"\\n{src_lang} - {tgt_lang}\")\n",
        "for input_sentence, translation in zip(hi_sents, mr_translations):\n",
        "    print(f\"{src_lang}: {input_sentence}\")\n",
        "    print(f\"{tgt_lang}: {translation}\")\n",
        "\n",
        "# flush the models to free the GPU memory\n",
        "del indic_indic_tokenizer, indic_indic_model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uyxXpt--Ma6n"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}