How does this actually work?

#5
by alpindale - opened

Are there any limitations on what sort of PyTorch code it can write Triton kernels for? For example:

$ cat infer.py 
from kernelllm import KernelLLM

# Initialize the model
model = KernelLLM()

# Define your PyTorch module
pytorch_code = '''
import torch

def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits
'''

# Generate optimized Triton code
optimized_code = model.generate_triton(pytorch_code, max_new_tokens=512)
print(optimized_code)

Output:

$ python infer.py 
Loading checkpoint shards: 100%|████████████████████████████████████████████████| 4/4 [00:06<00:00,  1.56s/it]
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code!

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda




@triton.jit
def triton_poi_fused__softmax_0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr
    ):
    xnumel = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = xindex // 4
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr0 + 4 * x1, xmask, eviction_policy='evict_last')
    tmp2 = tl.load(in_ptr0 + (1 + 4 * x1), xmask, eviction_policy='evict_last')
    tmp4 = tl.load(in_ptr0 + (2 + 4 * x1), xmask, eviction_policy='evict_last')
    tmp6 = tl.load(in_ptr0 + (3 + 4 * x1), xmask, eviction_policy='evict_last')
    tmp3 = triton_helpers.maximum(tmp1, tmp2)
    tmp5 = triton_helpers.maximum(tmp3, tmp4)
    tmp7 = triton_helpers.maximum(tmp5, tmp6)
    tmp8 = tmp0 - tmp7
    tmp9 = tl_math.exp(tmp8)
    tl.store(out_ptr0 + x2, tmp9, xmask)




@triton.jit
def triton_poi_fused__softmax_max_mul_1(in_ptr0, in_ptr1, out

This does not seem to be what I asked for.

AI at Meta org

Hi @alpindale,

Thank you for reaching out!
Unfortunately, the model is quite strict about the format of the PyTorch code it expects. This should be documented better, so I'll walk through it below using your code as an example. The function you are using is actually a really nice example of the kind of kernel that KernelLLM should work well for!

Example

TL;DR: Changing the format of your prompt to a full nn.Module solves the issue!

Let's get into it :)
Here is your original code, reproduced (its format is not ready for KernelLLM yet):

import torch

def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits

I'm taking this and rewriting it as a module Model(nn.Module), in exactly the way the in-context example in kernelllm.py presents it to the model:
Your code formatted for KernelLLM:
(Comments are just explanations added here)

import torch
import torch.nn as nn  # I always use these imports

class Model(nn.Module):  # I always call my model `Model`
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, logits, min_p):   # here I just renamed your function to forward
        probs = torch.softmax(logits, dim=-1)
        top_probs, _ = probs.max(dim=-1, keepdim=True)
        scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
        tokens_to_remove = probs < scaled_min_p
        logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

        return logits

def get_inputs():
    return [torch.randn(2, 128), torch.tensor([0.1, 0.1])]   # you have to decide on which shape you want to use. I am using a batch size of 2, to be somewhat realistic.

def get_init_inputs():  # just keep this even though your model doesn't need it.
    return []

Note that I made an arbitrary but somewhat sensible choice of input shapes, which need to be provided as well.
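If you have many functional snippets like `_apply_min_p`, the rewrite can be templated. The following is a minimal sketch and not part of the kernelllm package: `to_kernelllm_source` is a hypothetical helper that pastes a forward signature, a function body, and input expressions into the Model / get_inputs / get_init_inputs layout shown in the formatted example. It only builds a string, so torch does not need to be installed to run it.

```python
import textwrap

# Layout copied from the formatted example in this thread; the three
# placeholders are filled in by the hypothetical helper below.
_TEMPLATE = '''import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, {args}):
{body}

def get_inputs():
    return [{inputs}]

def get_init_inputs():
    return []
'''


def to_kernelllm_source(args, body, inputs):
    """Hypothetical helper: wrap a function body as Model.forward.

    args   -- forward arguments without `self`, e.g. ["logits", "min_p"]
    body   -- the function body as text (any consistent indentation)
    inputs -- Python expressions that build the example inputs
    """
    # Normalize the body's indentation, then indent it to sit inside forward().
    body = textwrap.indent(textwrap.dedent(body).strip("\n"), " " * 8)
    return _TEMPLATE.format(args=", ".join(args), body=body,
                            inputs=", ".join(inputs))


# Usage with the body of _apply_min_p from the question.
body = """
probs = torch.softmax(logits, dim=-1)
top_probs, _ = probs.max(dim=-1, keepdim=True)
scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
tokens_to_remove = probs < scaled_min_p
logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

return logits
"""
text = to_kernelllm_source(["logits", "min_p"], body,
                           ["torch.randn(2, 128)", "torch.tensor([0.1, 0.1])"])
print(text)
```

You still have to pick the input shapes yourself; the helper just saves retyping the boilerplate.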

I then feed the formatted code to KernelLLM:

from kernelllm import KernelLLM, PROMPT_TEMPLATE
text = """
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, logits, min_p):
        probs = torch.softmax(logits, dim=-1)
        top_probs, _ = probs.max(dim=-1, keepdim=True)
        scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
        tokens_to_remove = probs < scaled_min_p
        logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

        return logits

def get_inputs():
    return [torch.randn(2, 128), torch.tensor([0.1, 0.1])]

def get_init_inputs():
    return []
"""

prompt = PROMPT_TEMPLATE.format(text)

model = KernelLLM()
model.stream_raw(prompt)  # equivalent to model.generate_triton(text), but nicer to watch!
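Because the model is this strict about the layout, it can help to check the text before spending a generation on it. Below is a minimal sketch using only the standard-library `ast` module. `check_kernelllm_format` is a hypothetical helper, and the rules it checks (a `Model` class with a `forward` method, plus top-level `get_inputs` and `get_init_inputs` functions) are taken from the example in this thread, not from a published spec.

```python
import ast


def check_kernelllm_format(source):
    """Return a list of problems; an empty list means the layout looks right.

    Hypothetical helper: it checks only the structure used in this thread.
    The source is parsed, never executed, so torch need not be installed.
    """
    try:
        tree = ast.parse(source)
    except SyntaxError as exc:
        return [f"syntax error: {exc}"]

    problems = []
    # Collect top-level classes and functions by name.
    top_level = {node.name: node for node in tree.body
                 if isinstance(node, (ast.ClassDef, ast.FunctionDef))}

    model = top_level.get("Model")
    if not isinstance(model, ast.ClassDef):
        problems.append("missing class Model(nn.Module)")
    else:
        methods = {n.name for n in model.body if isinstance(n, ast.FunctionDef)}
        if "forward" not in methods:
            problems.append("Model has no forward method")

    for fn in ("get_inputs", "get_init_inputs"):
        if not isinstance(top_level.get(fn), ast.FunctionDef):
            problems.append(f"missing top-level function {fn}()")
    return problems


# The functional snippet from the question fails all three checks.
functional = "import torch\n\ndef _apply_min_p(logits, min_p):\n    return logits\n"
print(check_kernelllm_format(functional))
```

Running it on the original `_apply_min_p` snippet reports the missing `Model` class and both missing input functions, which is exactly the gap that the rewrite in this reply closes.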

KernelLLM produced the following output, which was not actually correct, but could be fixed in three spots to make it run.
I left a comment wherever I changed the model's output, with the original lines kept as comments:

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
import torch.nn as nn
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda




@triton.jit
def triton_poi_fused__softmax_max_0(in_ptr0, out_ptr0, out_ptr1, xnumel,
    XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + 2 * x0, None, eviction_policy='evict_last')
    tmp1 = tl.load(in_ptr0 + (1 + 2 * x0), None, eviction_policy='evict_last')
    tmp2 = triton_helpers.maximum(tmp0, tmp1)
    tmp3 = tmp0 - tmp2
    tmp4 = tl_math.exp(tmp3)
    tmp5 = tmp1 - tmp2
    tmp6 = tl_math.exp(tmp5)
    tmp7 = tmp4 + tmp6
    tmp8 = tmp4 / tmp7
    tmp9 = tmp6 / tmp7
    tmp10 = triton_helpers.maximum(tmp8, tmp9)
    tl.store(out_ptr0 + x0, tmp2, None)
    tl.store(out_ptr1 + x0, tmp10, None)




@triton.jit
def triton_poi_fused__softmax_lt_masked_fill_1(in_ptr0, in_ptr1, in_ptr2,
    out_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr):
    xnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = xindex // 2
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr0 + 2 * x1, xmask, eviction_policy='evict_last')
    tmp2 = tl.load(in_ptr0 + (1 + 2 * x1), xmask, eviction_policy='evict_last')
    tmp8 = tl.load(in_ptr1 + x1, xmask, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr2 + x1, xmask, eviction_policy='evict_last')
    tmp3 = tmp1 > tmp2
    tmp4 = tmp1 == tmp2
    tmp5 = tmp1 != tmp1
    tmp6 = tmp2 != tmp2
    tmp7 = tmp5 > tmp6
    tmp9 = tmp3 | tmp7
    tmp11 = tmp10 * tmp8
    tmp12 = tmp0 < tmp11
    # FIX: set tmp13 to -inf 
    # tmp13 = -1.0000000116860974e-07
    tmp13 = -float('inf')
    tmp14 = tl.where(tmp12, tmp13, tmp0)
    tl.store(out_ptr0 + x2, tmp9, xmask)
    tl.store(out_ptr1 + x2, tmp14, xmask)


def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (2, 128), (128, 1))
    assert_size_stride(arg1_1, (2,), (1,))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((2, 1), (1, 2), torch.float32)
        buf1 = empty_strided_cuda((2, 1), (1, 2), torch.float32)
        get_raw_stream(0)
        triton_poi_fused__softmax_max_0[grid(2)](arg0_1, buf0, buf1, 2,
            XBLOCK=2, num_warps=1, num_stages=1)
        buf2 = empty_strided_cuda((2, 128), (128, 1), torch.bool)
        buf3 = empty_strided_cuda((2, 128), (128, 1), torch.float32)
        triton_poi_fused__softmax_lt_masked_fill_1[grid(256)](arg0_1, buf1,
            buf0, buf2, buf3, 256, XBLOCK=128, num_warps=4, num_stages=1)
        del arg0_1
        del buf0
        del buf1
    # FIX: remove reinterpret_tensor which is undefined
    # return buf3, reinterpret_tensor(arg1_1, (2, 1), (1, 1), 0), buf2
    return buf3,


class ModelNew(nn.Module):
    def __init__(self):
        super(ModelNew, self).__init__()

    def forward(self, input_0, input_1):
        arg0_1 = input_0
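        # FIX: squeeze min_p to shape [2], which the size assertion in call() expects.
        # reshape(-1) handles both [2, 1] (after the original Model's in-place
        # unsqueeze_) and a fresh [2] tensor from get_inputs().
        # arg1_1 = input_1
        arg1_1 = input_1.reshape(-1)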
        output = call([arg0_1, arg1_1])
        return output[0]
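Why the first fix matters: after min-p filtering, the removed tokens are supposed to get zero probability when the logits are later passed through a softmax, as is typical for sampling. The model's constant `-1.0000000116860974e-07` is essentially zero rather than `-inf`, so a "removed" token keeps real probability mass. A small, dependency-free illustration, where `softmax` is a plain-Python stand-in rather than a KernelLLM, Triton, or PyTorch function:

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats (plain-Python stand-in)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # math.exp(-inf) is exactly 0.0
    total = sum(exps)
    return [e / total for e in exps]


kept = 2.0
with_inf = softmax([kept, -math.inf])                  # what masked_fill_(..., -inf) intends
with_tiny = softmax([kept, -1.0000000116860974e-07])   # the model's original constant

print(with_inf)                 # [1.0, 0.0]
print(round(with_tiny[1], 2))   # 0.12: the masked token still gets ~12% of the mass
```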

With these fixes, ModelNew now matches the original model's forward pass on the test inputs:

m_orig = Model()
m_new = ModelNew()
inputs = get_inputs()
inputs_cuda = [i.cuda() for i in inputs]

out_orig = m_orig(*inputs)
out_new = m_new(*inputs_cuda)

print(out_orig == out_new.cpu())
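For spot-checking a generated kernel on small, hand-picked rows where the right answer is easy to see, a plain-Python reference of the same min-p rule can be handy. This is a sketch assuming each row is filtered with its own `min_p` value, as the broadcasting in `_apply_min_p` implies. `apply_min_p_reference` is a hypothetical name, and unlike the original it does not modify its inputs in place.

```python
import math


def apply_min_p_reference(logits, min_p):
    """Plain-Python reference for _apply_min_p (hypothetical helper).

    logits -- list of rows (lists of floats), shape [batch, vocab]
    min_p  -- list of floats, one per row, shape [batch]
    Returns new rows; the inputs are left untouched (unlike masked_fill_).
    """
    out = []
    for row, p in zip(logits, min_p):
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probs = [e / total for e in exps]        # torch.softmax(logits, dim=-1)
        threshold = p * max(probs)               # min_p * top_probs
        out.append([-math.inf if q < threshold else x   # masked_fill_ with -inf
                    for x, q in zip(row, probs)])
    return out


print(apply_min_p_reference([[2.0, 1.0, -5.0], [0.0, 0.0, 0.0]], [0.1, 0.1]))
# [[2.0, 1.0, -inf], [0.0, 0.0, 0.0]]
```

Note that the original Model mutates its arguments (via unsqueeze_ and masked_fill_), so when comparing against it on real tensors, passing clones such as `[i.clone() for i in inputs]` keeps the inputs intact for a second run.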

I hope this helps!

Kind regards,
Zacharias

How can we train a model to be more flexible? I am also training a model for Triton and have the same problem: my model and training pipeline expect a functional format.

I believe the "problem" with supervised fine-tuning is that you burn your samples into the LLM pretty strongly. On the upside, you can really teach the model new capabilities.

In general, the model will learn to be as flexible as the data it was trained on, so mixing in general coding-related and instruction-following data, plus good old-fashioned data curation, should help.
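One simple way to read "mixing in" other data is weighted sampling across several sources when building the fine-tuning set. A minimal sketch using only the standard library. The source names, placeholder examples, and weights below are made up for illustration and are not a recommended recipe. Including both module-format and functional-format kernel tasks reflects the point that the model only becomes as flexible as its data.

```python
import random


def mix_datasets(sources, weights, n, seed=0):
    """Draw n examples, picking a source for each draw in proportion to its weight.

    sources -- dict mapping a source name to a list of examples
    weights -- dict mapping the same names to non-negative weights
    Returns a list of (source_name, example) pairs.
    """
    rng = random.Random(seed)  # seeded so the mixture is reproducible
    names = list(sources)
    w = [weights[name] for name in names]
    mixed = []
    for name in rng.choices(names, weights=w, k=n):
        mixed.append((name, rng.choice(sources[name])))
    return mixed


# Illustrative placeholder data only.
sources = {
    "triton_module_format": ["<nn.Module task 1>", "<nn.Module task 2>"],
    "triton_functional_format": ["<function task 1>", "<function task 2>"],
    "general_instructions": ["<instruction sample 1>", "<instruction sample 2>"],
}
weights = {"triton_module_format": 0.4, "triton_functional_format": 0.4,
           "general_instructions": 0.2}
batch = mix_datasets(sources, weights, n=1000)
```

Sampling with replacement like this keeps the proportions stable even when the sources differ a lot in size. Curation (deduplication, filtering out kernels that fail to compile or to match the reference) would happen before this step.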
