# Hugging Face Space: Indic Translation Demo (app.py)
| import os, sys | |
| # from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartForConditionalGeneration | |
| # import torch | |
| import gradio as gr | |
| import requests | |
| import json | |
| # from huggingface_hub import login | |
class LTRC_Translation_API():
    """Thin HTTP client for the IIIT-Hyderabad SSMT translation service.

    Translates text between a fixed set of Indic languages and English by
    POSTing JSON to the service endpoint.  Language ids are given as
    ISO 639-1 two-letter codes and mapped to the service's three-letter
    codes; unknown codes fall back to English -> Telugu.
    """

    def __init__(self, url='https://ssmt.iiit.ac.in/onemt', src_lang='en',
                 tgt_lang='te', timeout=30):
        # Two-letter ISO codes -> service-specific three-letter codes.
        self.lang_map = {'te': 'tel', 'en': 'eng', 'ta': 'tam', 'ml': 'mal',
                         'mr': 'mar', 'kn': 'kan', 'hi': 'hin'}
        self.url = url
        self.headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        # Fall back to the *mapped* three-letter defaults ('eng'/'tel');
        # the previous fallbacks ('en'/'te') were not valid service codes.
        self.src_lang = self.lang_map.get(src_lang, 'eng')
        self.tgt_lang = self.lang_map.get(tgt_lang, 'tel')
        # Seconds before a request is abandoned; prevents indefinite hangs.
        self.timeout = timeout

    def translate(self, text):
        """Return the translation of *text*, or '' on any request failure.

        Best-effort by design: network errors, HTTP errors, and malformed
        JSON responses are reported to stdout and swallowed so callers
        (the Gradio UI) never see an exception.
        """
        try:
            data = {'text': text,
                    'source_language': self.src_lang,
                    'target_language': self.tgt_lang}
            response = requests.post(self.url, headers=self.headers,
                                     json=data, timeout=self.timeout)
            response.raise_for_status()  # surface HTTP 4xx/5xx as errors
            # response.json() handles the content decoding for us.
            return response.json().get('data', '')
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decoding failures.
            print("Exception: ", e)
            return ''
| # class Headline_Generation(): | |
| # def __init__(self, model_name = "lokeshmadasu42/sample"): | |
| # self.model_name = model_name | |
| # self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
| # self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, use_fast=False, keep_accents=True) | |
| # self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name) | |
| # self.model.to(self.device) | |
| # self.model.eval() | |
| # self.bos_id = self.tokenizer._convert_token_to_id_with_added_voc("<s>") | |
| # self.eos_id = self.tokenizer._convert_token_to_id_with_added_voc("</s>") | |
| # self.pad_id = self.tokenizer._convert_token_to_id_with_added_voc("<pad>") | |
| # self.lang_map = {'as': '<2as>', 'bn': '<2bn>', 'en': '<2en>', 'gu': '<2gu>', 'hi': '<2hi>', 'kn': '<2kn>', 'ml': '<2ml>', 'mr': '<2mr>', 'or': '<2or>', 'pa': '<2pa>', 'ta': '<2ta>', 'te': '<2te>'} | |
| # print("Headline Generation model loaded...!") | |
| # def get_headline(self, text, lang_id): | |
| # inp = self.tokenizer(text, add_special_tokens=False, return_tensors="pt", padding=True).to(self.device) | |
| # inp = inp['input_ids'] | |
| # lang_code = self.lang_map.get(lang_id, '') | |
| # text = text + "</s> " + lang_code | |
| # # print("Text: ", text) | |
| # model_output = self.model.generate( | |
| # inp, | |
| # use_cache=True, | |
| # num_beams=5, | |
| # max_length=32, | |
| # min_length=1, | |
| # early_stopping=True, | |
| # pad_token_id = self.pad_id, | |
| # bos_token_id = self.bos_id, | |
| # eos_token_id = self.eos_id, | |
| # decoder_start_token_id = self.tokenizer._convert_token_to_id_with_added_voc(lang_code) | |
| # ) | |
| # decoded_output = self.tokenizer.decode( | |
| # model_output[0], | |
| # skip_special_tokens=True, | |
| # clean_up_tokenization_spaces=False | |
| # ) | |
| # return decoded_output | |
| # class Summarization(): | |
| # def __init__(self, model_name = "ashokurlana/mBART-TeSum"): | |
| # self.model_name = model_name | |
| # self.device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") | |
| # self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name) | |
| # self.model.to(self.device) | |
| # self.model.eval() | |
| # self.lang_map = {'te': 'te_IN', 'en': 'en_XX'} | |
| # print("Summarization model loaded...!") | |
| # def get_summary(self, text, lang_id): | |
| # inp = self.tokenizer([text], add_special_tokens=False, return_tensors="pt", max_length = 1024).to(self.device) | |
| # inp = inp['input_ids'] | |
| # lang_code = self.lang_map.get(lang_id, '') | |
| # model_output = self.model.generate( | |
| # inp, | |
| # use_cache=True, | |
| # num_beams=5, | |
| # max_length=256, | |
| # early_stopping=True | |
| # ) | |
| # decoded_output = [self.tokenizer.decode( | |
| # summ_id, | |
| # skip_special_tokens=True, | |
| # clean_up_tokenization_spaces=False | |
| # ) for summ_id in model_output] | |
| # return " ".join(decoded_output) | |
def get_prediction(text, src_lang_id, tgt_lang_id, translate=False):
    """Translate *text* from *src_lang_id* to *tgt_lang_id* via the LTRC API.

    Args:
        text: source text to translate.
        src_lang_id: two-letter source language code (e.g. 'en').
        tgt_lang_id: two-letter target language code (e.g. 'te').
        translate: unused; retained for backward compatibility with the
            earlier headline/summary pipeline's call signature.

    Returns:
        The translated text, or '' if the API call fails.
    """
    # NOTE(review): the previous revision carried ~25 lines of commented-out
    # pipeline code here, including a hard-coded Hugging Face access token.
    # That code (and the leaked credential) has been removed; the token
    # should be revoked and, if needed again, supplied via an environment
    # variable or HF Spaces secret.
    translator = LTRC_Translation_API(src_lang=src_lang_id,
                                      tgt_lang=tgt_lang_id)
    return translator.translate(text)
# Language codes offered by both dropdowns of the demo UI.
_LANG_CHOICES = ['en', 'hi', 'kn', 'ml', 'mr', 'ta', 'te']

# Wire the translation function into a simple Gradio form:
# one source-text box, two language dropdowns, one output box.
interface = gr.Interface(
    fn=get_prediction,
    inputs=[
        gr.Textbox(
            lines=8,
            label="Source Text",
            info="Provide the news article text here",
        ),
        gr.Dropdown(
            _LANG_CHOICES,
            label="Source Language code",
            info="select the source language code",
        ),
        gr.Dropdown(
            _LANG_CHOICES,
            label="Target Language code",
            info="select the target language code",
        ),
    ],
    outputs=[
        gr.Textbox(lines=8, label="Translation", info="Translated text"),
    ],
    title="Indic Translation Demo",
)

interface.launch(share=True)