{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "This tutorial demonstrates how to perform evaluation on a gpt-j-6B-int8 model."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "!pip install onnx onnxruntime torch transformers datasets accelerate"
   ]
  },
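  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluation below assumes you already have an INT8 ONNX export of GPT-J-6B, referred to throughout as `/path/to/model.onnx` (substitute your own path). If you still need to produce one, the next cell is a minimal sketch of one possible route using `optimum` and ONNX Runtime's dynamic quantizer. It is an illustration, not necessarily the recipe used for the model this tutorial was written against, and the resulting input names and `past_key_values` shapes may differ from the ones assumed later on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# A minimal sketch (an assumption, not this tutorial's exact recipe):\n",
    "# export GPT-J-6B to ONNX, then apply dynamic weight-only INT8 quantization.\n",
    "# Requires: pip install optimum[onnxruntime]\n",
    "from optimum.onnxruntime import ORTModelForCausalLM\n",
    "from onnxruntime.quantization import QuantType, quantize_dynamic\n",
    "\n",
    "# export the FP32 model to ONNX (needs substantial RAM and disk space)\n",
    "ORTModelForCausalLM.from_pretrained(\"EleutherAI/gpt-j-6B\", export=True).save_pretrained(\"gptj_onnx\")\n",
    "\n",
    "# quantize weights to INT8; the paths here are illustrative\n",
    "quantize_dynamic(\n",
    "    \"gptj_onnx/model.onnx\",\n",
    "    \"gptj_int8/model.onnx\",\n",
    "    weight_type=QuantType.QInt8,\n",
    "    use_external_data_format=True,  # 6B weights exceed the 2 GB protobuf limit\n",
    ")"
   ]
  },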
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run\n",
    "\n",
    "### 1. Get lambada acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import torch\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "import onnxruntime as ort\n",
    "from torch.nn.functional import pad\n",
    "\n",
    "# load model\n",
    "model_id = \"EleutherAI/gpt-j-6B\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    example = tokenizer(examples['text'])\n",
    "    return example\n",
    "\n",
    "# create dataset\n",
    "dataset = load_dataset('lambada', split='validation')\n",
    "dataset = dataset.shuffle(seed=42)\n",
    "dataset = dataset.map(tokenize_function, batched=True)\n",
    "dataset.set_format(type='torch', columns=['input_ids'])\n",
    "\n",
    "# create session\n",
    "options = ort.SessionOptions()\n",
    "options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
    "session = ort.InferenceSession('/path/to/model.onnx', options, providers=ort.get_available_providers())\n",
    "total, hit = 0, 0\n",
    "index = 1\n",
    "\n",
    "# inference\n",
    "for idx, batch in enumerate(dataset):\n",
    "    input_ids = batch['input_ids'].unsqueeze(0)\n",
    "    label = input_ids[:, -1]\n",
    "    pad_len = 0  ##set to 0\n",
    "    input_ids = pad(input_ids, (0, pad_len), value=1)\n",
    "    ort_inputs = {\n",
    "        'input_ids': input_ids.detach().cpu().numpy(),\n",
    "        'attention_mask': torch.cat([torch.ones(input_ids.shape), torch.ones([1, 1])], dim=-1).detach().cpu().numpy().astype('int64')\n",
    "    }\n",
    "    for i in range(28):\n",
    "        ort_inputs[\"past_key_values.{}.key\".format(i)] = np.zeros((1,16,1,256), dtype='float32')\n",
    "        ort_inputs[\"past_key_values.{}.value\".format(i)] = np.zeros((1,16,1,256), dtype='float32')\n",
    "    predictions = session.run(None, ort_inputs)\n",
    "    outputs = torch.from_numpy(predictions[0]) \n",
    "    last_token_logits = outputs[:, -2 - pad_len, :]\n",
    "    pred = last_token_logits.argmax(dim=-1)\n",
    "    total += label.size(0)\n",
    "    hit += (pred == label).sum().item()\n",
    "\n",
    "acc = hit / total\n",
    "print('acc: ', acc)"
   ]
  },
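  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dummy `past_key_values` shape `(1, 16, 1, 256)` matches GPT-J-6B's 16 attention heads and 256-dim head size across its 28 layers, but the exact input names and shapes depend on how the model was exported. As a sanity check, you can list what the session actually expects:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# print the ONNX model's expected input names, shapes, and element types,\n",
    "# so the dummy past_key_values above can be matched to your export\n",
    "for node in session.get_inputs():\n",
    "    print(node.name, node.shape, node.type)"
   ]
  },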
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# batch inference\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "import torch\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "import onnxruntime as ort\n",
    "from torch.nn.functional import pad\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "batch_size = 2\n",
    "pad_max = 196\n",
    "\n",
    "# load model\n",
    "model_id = \"EleutherAI/gpt-j-6B\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    example = tokenizer(examples['text'])\n",
    "    return example\n",
    "\n",
    "# create dataloader\n",
    "class Dataloader:\n",
    "    def __init__(self, pad_max=196, batch_size=1, sub_folder='validation'):\n",
    "        self.pad_max = pad_max\n",
    "        self.batch_size=batch_size\n",
    "        dataset = load_dataset('lambada', split=sub_folder)\n",
    "        dataset = dataset.map(tokenize_function, batched=True)\n",
    "        dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\"])\n",
    "        self.dataloader = DataLoader(\n",
    "            dataset,\n",
    "            batch_size=self.batch_size,\n",
    "            shuffle=False,\n",
    "            collate_fn=self.collate_batch,\n",
    "        )\n",
    "\n",
    "    def collate_batch(self, batch):\n",
    "        input_ids_padded = []\n",
    "        attention_mask_padded = []\n",
    "        last_ind = []\n",
    "        for text in batch:\n",
    "            input_ids = text[\"input_ids\"] if text[\"input_ids\"].shape[0] <= self.pad_max else text[\"input_ids\"][0:int(self.pad_max-1)]\n",
    "            pad_len = self.pad_max - input_ids.shape[0]\n",
    "            last_ind.append(input_ids.shape[0] - 1)\n",
    "            input_ids = pad(input_ids, (0, pad_len), value=1)\n",
    "            input_ids_padded.append(input_ids)\n",
    "            attention_mask = torch.ones(input_ids.shape[0] + 1)\n",
    "            attention_mask_padded.append(attention_mask)\n",
    "        return (torch.vstack(input_ids_padded), torch.vstack(attention_mask_padded)), torch.tensor(last_ind)\n",
    "\n",
    "    def __iter__(self):\n",
    "        try:\n",
    "            for (input_ids, attention_mask), last_ind in self.dataloader:\n",
    "                data = [input_ids.detach().cpu().numpy().astype('int64')]\n",
    "                data.append(attention_mask.detach().cpu().numpy().astype('int64'))\n",
    "                yield data, last_ind.detach().cpu().numpy()\n",
    "        except StopIteration:\n",
    "            return\n",
    "\n",
    "# create session\n",
    "options = ort.SessionOptions()\n",
    "options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
    "session = ort.InferenceSession('/path/to/model.onnx', options, providers=ort.get_available_providers())\n",
    "total, hit = 0, 0\n",
    "\n",
    "dataloader = Dataloader(pad_max=pad_max, batch_size=batch_size)\n",
    "\n",
    "# inference\n",
    "for idx, (batch, last_ind) in enumerate(dataloader):\n",
    "    label = torch.from_numpy(batch[0][torch.arange(len(last_ind)), last_ind])\n",
    "    pad_len = pad_max - last_ind - 1\n",
    "    ort_inputs = {\n",
    "        'input_ids': batch[0],\n",
    "        'attention_mask': batch[1]\n",
    "    }\n",
    "    for i in range(28):\n",
    "        ort_inputs[\"past_key_values.{}.key\".format(i)] = np.zeros((batch_size,16,1,256), dtype='float32')\n",
    "        ort_inputs[\"past_key_values.{}.value\".format(i)] = np.zeros((batch_size,16,1,256), dtype='float32')\n",
    " \n",
    "    predictions = session.run(None, ort_inputs)\n",
    "    outputs = torch.from_numpy(predictions[0])\n",
    "    last_token_logits = outputs[torch.arange(len(last_ind)), -2 - pad_len, :]\n",
    "    pred = last_token_logits.argmax(dim=-1)\n",
    "    total += len(label)\n",
    "    hit += (pred == label).sum().item()\n",
    "\n",
    "acc = hit / total\n",
    "print('acc: ', acc)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Text Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import sys\n",
    "\n",
    "# create session\n",
    "sess_options = ort.SessionOptions()\n",
    "sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
    "session = ort.InferenceSession('/path/to/model.onnx', sess_options)\n",
    "\n",
    "# input prompt\n",
    "# 32 tokens input\n",
    "prompt = \"Once upon a time, there existed a little girl, who liked to have adventures.\" + \\\n",
    "                 \" She wanted to go to places and meet new people, and have fun.\"\n",
    "\n",
    "print(\"prompt: \", prompt)\n",
    "\n",
    "total_time = 0.0\n",
    "num_iter = 10\n",
    "num_warmup = 3\n",
    "\n",
    "# start\n",
    "for idx in range(num_iter):\n",
    "    text = []\n",
    "    tic = time.time()\n",
    "\n",
    "    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
    "\n",
    "    attention_mask = torch.ones(input_ids.shape[1] +1)\n",
    "    attention_mask[0] = 0\n",
    "    attention_mask = attention_mask.unsqueeze(0)\n",
    "\n",
    "    inp = {'input_ids': input_ids.detach().cpu().numpy(),\n",
    "            'attention_mask': attention_mask.detach().cpu().numpy().astype('int64')}\n",
    "    for i in range(28):\n",
    "        inp[\"past_key_values.{}.key\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
    "        inp[\"past_key_values.{}.value\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
    "\n",
    "    for step in range(32):\n",
    "\n",
    "        output = session.run(None, inp)\n",
    "        logits = output[0]\n",
    "        logits = torch.from_numpy(logits)\n",
    "        next_token_logits = logits[:, -1, :]\n",
    "        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)\n",
    "        next_tokens = torch.argmax(probs, dim=-1)\n",
    "        present_kv = output[1]\n",
    "        for i in range(28):\n",
    "\n",
    "            if step == 0:\n",
    "                inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1][:, :, 1:, :]\n",
    "                inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2][:, :, 1:, :]\n",
    "            else:\n",
    "                inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1]\n",
    "                inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2]\n",
    "\n",
    "        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
    "        if step == 0:\n",
    "            attention_mask = torch.cat([attention_mask[:, 1:], torch.ones([1, 1])], dim=-1)\n",
    "        else:\n",
    "            attention_mask = torch.cat([attention_mask, torch.ones([1, 1])], dim=-1)\n",
    "\n",
    "        inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')\n",
    "        inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()\n",
    "\n",
    "    print(tokenizer.decode(input_ids[0]))\n",
    "    toc = time.time()\n",
    "    if idx >= num_warmup:\n",
    "        total_time += (toc - tic)\n",
    "print(\"Inference latency: %.3f s.\" % (total_time / (num_iter - num_warmup)))"
   ]
  },
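  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check on generation quality, you can compare the INT8 ONNX output above against greedy decoding with the original FP32 PyTorch model. This is a sketch under the assumption that you have enough RAM to load the 6B-parameter model in float32 (roughly 24 GB); it is slow and purely for eyeballing that the quantized generations stay reasonable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# optional FP32 reference: greedy-decode the same prompt with the original\n",
    "# PyTorch model (loading GPT-J-6B in float32 needs roughly 24 GB of RAM)\n",
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "ref_model = AutoModelForCausalLM.from_pretrained(\"EleutherAI/gpt-j-6B\")\n",
    "ref_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
    "ref_out = ref_model.generate(ref_ids, max_new_tokens=32, do_sample=False)\n",
    "print(tokenizer.decode(ref_out[0]))"
   ]
  }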
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}