Spaces:
Runtime error
Runtime error
| import nltk | |
| import spacy | |
| from word2number import w2n | |
| import inflect | |
| from num2words import num2words | |
| p = inflect.engine() | |
| import numpy as np | |
| import random | |
# Fetch the NLTK data used for tokenising and POS tagging, then load the
# spaCy English pipeline. Both run once, at import time.
for _nltk_resource in ('punkt', 'averaged_perceptron_tagger'):
    nltk.download(_nltk_resource)
nlp = spacy.load('en_core_web_sm')
# Object names made up of two words. They are matched against the caption as
# whole phrases, so they have to be listed explicitly.
SPECIAL_WORDS = [
    "baseball bat",
    "baseball glove",
    "cell phone",
    "dining table",
    "fire hydrant",
    "french fries",
    "hair drier",
    "hot dog",
    "parking meter",
    "potted plant",
    "soccer ball",
    "soccer player",
    "sports ball",
    "stop sign",
    "teddy bear",
    "tennis racket",
    "toy figure",
    "traffic light",
    "wine glass",
]
def _get_nouns(lines):
    """Count the nouns mentioned in a caption.

    Two-word object names from ``SPECIAL_WORDS`` are detected by substring
    match and cut out of the text first; the remaining text is tokenised and
    POS-tagged with NLTK, and every noun token is counted.

    Args:
        lines: Caption text to analyse.

    Returns:
        A dict mapping noun -> count. Nouns seen once are singularised and
        nouns seen several times are pluralised (via ``inflect``). Number
        words and tokens shorter than three characters are dropped. Every
        ``SPECIAL_WORDS`` entry found in the text is reported with count 1,
        regardless of how often it occurs.
    """
    # Pull the two-word names out before tagging so their parts are not
    # counted as separate nouns.
    present_words = [s for s in SPECIAL_WORDS if s in lines]
    for w in present_words:
        lines = lines.replace(w, "")

    # Every noun tag (NN, NNS, NNP, NNPS) starts with "NN".
    tokenized = nltk.word_tokenize(lines)
    nouns = [word for word, pos in nltk.pos_tag(tokenized) if pos.startswith("NN")]

    # Drop one occurrence each of the words "objects" and "image".
    for filler in ("objects", "image"):
        if filler in nouns:
            nouns.remove(filler)

    noun_counts = {}
    for n in nouns:
        noun_counts[n] = noun_counts.get(n, 0) + 1

    result = {}
    for k, v in noun_counts.items():
        # "bus" and "skis" are exempt from inflection.
        if k not in ("bus", "skis"):
            singular = p.singular_noun(k)  # False when k is already singular
            if v == 1:
                if singular:
                    k = singular
            elif not singular:
                k = p.plural(k)

        # A token that parses as a number word is not an object name.
        try:
            w2n.word_to_num(k)
        except (ValueError, IndexError):
            # Not a number word (word2number raises ValueError; malformed
            # multi-part words can surface as IndexError): keep it as a noun.
            pass
        else:
            continue

        # Very short tokens are discarded -- presumably leftovers such as the
        # "es" that remains after cutting "wine glass" out of "wine glasses".
        if len(k) < 3:
            continue

        # Patch known bad inflections.
        if k == "ski":
            k = "skis"
        elif k == "gras":
            k = "grass"
        result[k] = v

    for w in present_words:
        result[w] = 1
    return result
def _get_num_nouns(lines):
    """Extract explicitly counted nouns such as "two people" from a caption.

    Colons and periods are removed, the text is parsed with spaCy, and every
    noun chunk containing a ``NUM`` token is examined. Each chunk is split on
    ", "; in every piece the first word is read as the number and the rest
    as the noun.

    Args:
        lines: Caption text to analyse.

    Returns:
        A dict mapping noun phrase -> integer count. Pieces whose first word
        is not a number word, and pieces with no noun after the number, are
        skipped.
    """
    lines = lines.replace(":", "").replace(".", "")
    doc = nlp(lines)
    counted_chunks = [
        chunk.text
        for chunk in doc.noun_chunks
        if any(token.pos_ == 'NUM' for token in chunk)
    ]

    num_noun_dict = {}
    for chunk_text in counted_chunks:
        for phrase in chunk_text.split(", "):
            # First word is the number, everything after the first space is
            # the noun.
            number_word, _, noun = phrase.partition(" ")
            if not noun:
                # A bare number has nothing to count; avoid an empty key.
                continue
            if noun == "ski":
                noun = "skis"
            try:
                count = w2n.word_to_num(number_word)
            except (ValueError, IndexError):
                # Leading word is not a number (e.g. "the two dogs"): skip.
                continue
            num_noun_dict[noun] = count
    return num_noun_dict
def _obtain_nouns(gt):
    """Build the final noun -> count mapping for a caption.

    Combines the explicitly counted nouns from ``_get_num_nouns`` with the
    plain noun counts from ``_get_nouns``. A plain noun is dropped when its
    plural form already appears among the explicitly counted nouns, so the
    explicit count is the one that is kept.

    Args:
        gt: Caption text. It is lower-cased before processing.

    Returns:
        A dict mapping noun (or noun phrase) -> count.
    """
    # Lower-case first so that "Hair Dryer" is also rewritten to the
    # "hair drier" spelling used in SPECIAL_WORDS.
    gt = gt.lower().replace("hair dryer", "hair drier")
    nouns_gt = _get_nouns(gt)
    num_nouns_gt = _get_num_nouns(gt)

    # Keep only plain nouns that were not already counted explicitly.
    uncounted = {
        k: v for k, v in nouns_gt.items() if p.plural(k) not in num_nouns_gt
    }
    return {**num_nouns_gt, **uncounted}
def generate_qa_pairs(text):
    """Generate counting question/answer pairs from an object caption.

    For every object found by ``_obtain_nouns`` two pairs are produced:

    * a "How many ... are there in the image?" question with its answer, and
    * a yes/no question about a specific count. Objects with count 1 always
      get the true count ("Yes."); other objects get the true count when a
      uniform draw exceeds 0.7, otherwise a wrong count ("No.").

    Args:
        text: Caption listing the objects in the image.

    Returns:
        A list of ``(question, answer)`` tuples: a random sample of between
        one and six of the generated pairs (fewer if fewer exist). The output
        is random; it draws from both ``numpy.random`` and ``random``.
    """
    num_nouns = _obtain_nouns(text)
    qa_pairs = []
    for obj, count in num_nouns.items():
        # Count question: the question always uses the plural form.
        plural_obj = p.plural(obj) if count == 1 else obj
        count_question = f"How many {plural_obj} are there in the image?"
        count_answer = f"There {'is' if count == 1 else 'are'} {num2words(count)} {obj} in the image."
        qa_pairs.append((count_question, count_answer))

        # Yes/no question. The draw is made for every object, even when
        # count == 1 makes the outcome irrelevant.
        prob_positive = np.random.uniform(0, 1.)
        if prob_positive > 0.7 or count == 1:
            numeric_presence_question = f"{'Is' if count == 1 else 'Are'} there {num2words(count)} {obj} in the image?"
            numeric_presence_answer = "Yes."
        elif count > 1:
            # Pick a wrong count from 2 .. count + 5, excluding the true one.
            wrong_counts = [i for i in range(2, count + 6) if i != count]
            cnt = random.choice(wrong_counts)
            # cnt is always >= 2, so the plural verb is always correct.
            numeric_presence_question = f"Are there {num2words(cnt)} {obj} in the image?"
            numeric_presence_answer = "No."
        else:
            # count < 1 (e.g. "zero"): no yes/no pair is built. Skip instead
            # of appending an unset or stale pair from a previous object.
            continue
        qa_pairs.append((numeric_presence_question, numeric_presence_answer))

    random.shuffle(qa_pairs)
    return random.sample(qa_pairs, min(len(qa_pairs), random.choice([1, 2, 3, 4, 5, 6])))
if __name__ == "__main__":
    # Small demo: build QA pairs for a sample caption and print them.
    text = (
        "The objects present in the image are: wall, ceiling, shelf, cabinet, "
        "counter, dining table, two people, eighteen bottles, two wine glasses, "
        "refrigerator, tv, bowl"
    )
    qa = generate_qa_pairs(text)
    # icecream is only needed for this demo, so it is imported here.
    from icecream import ic
    ic(qa)