In [1]:
##%%
import os
import pickle
import json
# import random
# import torch
# import numpy as np
# import argparse
# import cohere
# from openai import OpenAI


In [2]:
##%%
# import hashlib
from tqdm import tqdm
from itertools import product
# from collections import Counter

# from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers import AutoTokenizer, AutoModelForCausalLM
from textgames import GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Default output directory for textgames unless the environment already sets
# one; the cell echoes the effective value.
# NOTE(review): this runs after `from textgames import ...`. If textgames reads
# TEXTGAMES_OUTPUT_DIR at import time, this default arrives too late — confirm,
# and move it above the import if so.
default_output_dir = "user_outputs"
os.environ.setdefault("TEXTGAMES_OUTPUT_DIR", default_output_dir)

'user_outputs'

In [4]:
##%%
# Model under evaluation, and whether to build the extra loading kwargs
# (see the next cell for what that flag actually controls).
gen_model_checkpoint, quantize = "google/gemma-2-9b-it", True

In [5]:
# Extra keyword arguments for model loading.
# NOTE: despite the flag name, no quantization config is passed here; the only
# effect of `quantize` is letting the weights be placed via device_map="auto".
if quantize:
    kwargs = {"device_map": "auto"}
else:
    kwargs = {}

In [6]:
##%%
# Load the chat model and its tokenizer. The placement kwargs (device_map) only
# make sense for the model; the tokenizer has no device, so it is loaded without
# them (otherwise they are forwarded as unused tokenizer init kwargs).
gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint)

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:24<00:00,  6.19s/it]


In [7]:
# Display the device reported by the loaded model; get_gemma_response moves its
# prompt tensors to this same device.
gen_model.device

device(type='cuda', index=0)

In [8]:
def get_gemma_response(text, max_new_tokens=100):
    """Generate a single-turn chat reply from ``gen_model`` for ``text``.

    The text is wrapped as one user message with the tokenizer's chat template,
    the model generates a continuation, and only the newly generated tokens
    (the prompt is sliced off) are decoded.

    Relies on the notebook globals ``gen_model`` and ``tokenizer``.

    Args:
        text: The user message content.
        max_new_tokens: Upper bound on generated tokens (default 100).

    Returns:
        The decoded reply with special tokens removed.
    """
    messages = [
        {"role": "user", "content": text},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(gen_model.device)

    # Stop on EOS or on the chat template's end-of-turn marker. Gemma's marker
    # is "<end_of_turn>"; "<|eot_id|>" is Llama-3's and is kept for other
    # checkpoints. A token missing from the vocabulary resolves to the unk id
    # (or None), so such ids are skipped instead of being used as stop tokens.
    terminators = [tokenizer.eos_token_id]
    for marker in ("<end_of_turn>", "<|eot_id|>"):
        token_id = tokenizer.convert_tokens_to_ids(marker)
        if token_id is None or token_id == tokenizer.unk_token_id:
            continue
        if token_id not in terminators:
            terminators.append(token_id)

    # temperature=.001 with sampling is effectively greedy decoding, but it is
    # still stochastic unless a random seed is fixed.
    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=.001,
        top_p=1,
    )

    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)

---
Example call on a single hand-written prompt

In [14]:
# @title
# Hand-written word-scoring prompt used to sanity-check get_gemma_response on a
# single example before the full evaluation loop.
# NOTE: this string is sent to the model verbatim, so its wording (including
# the "When there 2 or more words" typo) is intentionally left unchanged.
text = \
"""
Given a set of rules to calculate point, sort the set of words in decreasing order.
When there 2 or more words with same point, sort lexicographically.

Rules:
- every pair of consecutive consonant gets 5 points
- every pair of consecutive vowel gets 3 points
- add 1 point if there exists exactly 1 'g' in the word
- word less than 5 characters gets 10 points
- word starts with gen gets 100 points
- word ends with ta gets -1000 point

Words:
- genta
- winata
- hudi
- alham
- aji
- ruochen

Print only the answer.
"""

print(text)


Given a set of rules to calculate point, sort the set of words in decreasing order.
When there 2 or more words with same point, sort lexicographically.

Rules:
- every pair of consecutive consonant gets 5 points
- every pair of consecutive vowel gets 3 points
- add 1 point if there exists exactly 1 'g' in the word
- word less than 5 characters gets 10 points
- word starts with gen gets 100 points
- word ends with ta gets -1000 point

Words:
- genta
- winata
- hudi
- alham
- aji
- ruochen

Print only the answer.



In [None]:
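# Gold answer (points earned per rule, then the total):
# - aji      10                  = 10
# - hudi     10                  = 10     (ties with aji; lexicographic order)
# - ruochen   5  3               = 8
# - alham     5                  = 5
# - genta     5  1  100 -1000    = -894
# - winata  -1000                = -1000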

In [9]:
# Ask the model to solve the example prompt and show its raw reply.
example_reply = get_gemma_response(text)
print(example_reply)

genta
winata
ruochen
hudi
alham
aji 



---
Automatically run all game sessions

In [None]:
# Evaluate every (game, difficulty) pair and append each raw result to a pickle
# log. GAME_NAMES[4] runs first, followed by the remaining games in their
# original order; only the first three difficulty levels (LEVEL_IDS[:3]) are run.
#
# NOTE(review): the log is opened in append mode, so re-running this cell adds
# duplicate records instead of replacing earlier ones. Read it back with
# repeated pickle.load(...) calls until EOFError.
RESULTS_PATH = "model_outputs/results_gemma_2_9B_it.pkl"
# Append mode cannot create missing directories, so make sure the folder exists.
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)

game_order = [GAME_NAMES[4], *GAME_NAMES[:4], *GAME_NAMES[5:]]
for game_name, difficulty_level in product(game_order, LEVEL_IDS[:3]):
    game_cls = _game_class_from_name(game_name)
    game_file = game_filename(game_name)
    level_key = f"{game_file}_{difficulty_level}"
    with open(f"problemsets/{level_key}.json", "r", encoding="utf8") as f:
        sid_prompt_dict = json.load(f)

    correct_cnt = 0
    for sid, prompt in tqdm(list(sid_prompt_dict.items()), desc=f"{game_file}_-_{difficulty_level}"):
        cur_game = game_cls()
        cur_game.load_game(prompt)
        response = get_gemma_response(cur_game.get_prompt()).strip()
        solved, val_msg = cur_game.validate(response)
        # Write one record per problem and close the file right away, so the
        # results gathered so far survive an interrupted run.
        with open(RESULTS_PATH, "ab") as o:
            pickle.dump((level_key, sid, response, (solved, val_msg)), o)
        if solved:
            correct_cnt += 1

    print(f"{game_name}_-_{difficulty_level}")
    n_problems = len(sid_prompt_dict)
    if n_problems:
        print(f"  Acc.: {correct_cnt / n_problems:.2%}")
    else:
        # Avoid ZeroDivisionError on an empty problem set.
        print("  Acc.: n/a (empty problem set)")