How does this actually work?
Are there any limitations on the sort of PyTorch code it can write Triton kernels for? For example:
$ cat infer.py
from kernelllm import KernelLLM
# Initialize the model
model = KernelLLM()
# Define your PyTorch module
pytorch_code = '''
import torch
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
    return logits
'''
# Generate optimized Triton code
optimized_code = model.generate_triton(pytorch_code, max_new_tokens=512)
print(optimized_code)
Output:
$ python infer.py
Loading checkpoint shards: 100%|████████████████████████████████████████████████| 4/4 [00:06<00:00, 1.56s/it]
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code!
import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
@triton.jit
def triton_poi_fused__softmax_0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xnumel = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = xindex // 4
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr0 + 4 * x1, xmask, eviction_policy='evict_last')
    tmp2 = tl.load(in_ptr0 + (1 + 4 * x1), xmask, eviction_policy='evict_last')
    tmp4 = tl.load(in_ptr0 + (2 + 4 * x1), xmask, eviction_policy='evict_last')
    tmp6 = tl.load(in_ptr0 + (3 + 4 * x1), xmask, eviction_policy='evict_last')
    tmp3 = triton_helpers.maximum(tmp1, tmp2)
    tmp5 = triton_helpers.maximum(tmp3, tmp4)
    tmp7 = triton_helpers.maximum(tmp5, tmp6)
    tmp8 = tmp0 - tmp7
    tmp9 = tl_math.exp(tmp8)
    tl.store(out_ptr0 + x2, tmp9, xmask)
@triton.jit
def triton_poi_fused__softmax_max_mul_1(in_ptr0, in_ptr1, out
This does not seem to be what I asked for.
Hi @alpindale ,
thank you for reaching out!
Unfortunately, the model is pretty inflexible about the format of the torch code it expects. This should be explained better, so I'll walk through it below, using your code as an example. The function you are using is actually a really nice example of the kind of kernel that KernelLLM should work well for!
Example
TL;DR: Changing the format of your prompt to a full nn.Module solves the issue!
Let's get into it :)
Here is your original code reproduced (the format is not ready for KernelLLM yet):
import torch
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
    return logits
I'll take this and rewrite it as a module Model(nn.Module),
in exactly the way that the in-context example in kernelllm.py
demonstrates to the model:
Your code formatted for KernelLLM:
(Comments are just explanations added here)
import torch
import torch.nn as nn  # I always use these imports

class Model(nn.Module):  # I always call my model `Model`
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, logits, min_p):  # here I just renamed your function to forward
        probs = torch.softmax(logits, dim=-1)
        top_probs, _ = probs.max(dim=-1, keepdim=True)
        scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
        tokens_to_remove = probs < scaled_min_p
        logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
        return logits

def get_inputs():
    return [torch.randn(2, 128), torch.tensor([0.1, 0.1])]  # you have to decide on which shape you want to use. I am using a batch size of 2, to be somewhat realistic.

def get_init_inputs():  # just keep this even though your model doesn't need it.
    return []
Note that I made an arbitrary, but somewhat sensible choice for the input shapes, which need to be provided as well.
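If you find yourself doing this conversion a lot, the templating can be mechanized. The sketch below is only an illustration of that idea; wrap_for_kernelllm and MODEL_SKELETON are made-up names and not part of kernelllm.py, and you still have to supply the argument list, function body, and example inputs yourself:

import textwrap

# Hypothetical helper (not part of kernelllm.py): stamps a functional snippet
# into the Model / get_inputs / get_init_inputs skeleton that KernelLLM expects.
MODEL_SKELETON = '''import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, {args}):
{body}

def get_inputs():
    return [{inputs}]

def get_init_inputs():
    return []
'''

def wrap_for_kernelllm(args: str, body: str, inputs: str) -> str:
    # Re-indent the function body so it sits inside Model.forward (8 spaces deep).
    body = textwrap.indent(textwrap.dedent(body).strip(), " " * 8)
    return MODEL_SKELETON.format(args=args, body=body, inputs=inputs)

text = wrap_for_kernelllm(
    args="logits, min_p",
    body="""
        probs = torch.softmax(logits, dim=-1)
        top_probs, _ = probs.max(dim=-1, keepdim=True)
        scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
        tokens_to_remove = probs < scaled_min_p
        return logits.masked_fill_(tokens_to_remove, -float("inf"))
    """,
    inputs='torch.randn(2, 128), torch.tensor([0.1, 0.1])',
)
print(text)

For this walkthrough, though, I'll just use the hand-formatted module from above.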
I feed this to KernelLLM:
from kernelllm import KernelLLM, PROMPT_TEMPLATE
text = """
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, logits, min_p):
        probs = torch.softmax(logits, dim=-1)
        top_probs, _ = probs.max(dim=-1, keepdim=True)
        scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
        tokens_to_remove = probs < scaled_min_p
        logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
        return logits

def get_inputs():
    return [torch.randn(2, 128), torch.tensor([0.1, 0.1])]

def get_init_inputs():
    return []
"""
prompt = PROMPT_TEMPLATE.format(text)
model = KernelLLM()
model.stream_raw(prompt) # equivalent to model.generate_triton(text), but nicer to watch!
From this I got the following output, which was not actually correct, but could be fixed in three spots to make it run.
I left a comment wherever I changed the model's output and kept the original lines as comments:
import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
import torch.nn as nn
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
@triton.jit
def triton_poi_fused__softmax_max_0(in_ptr0, out_ptr0, out_ptr1, xnumel,
        XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + 2 * x0, None, eviction_policy='evict_last')
    tmp1 = tl.load(in_ptr0 + (1 + 2 * x0), None, eviction_policy='evict_last')
    tmp2 = triton_helpers.maximum(tmp0, tmp1)
    tmp3 = tmp0 - tmp2
    tmp4 = tl_math.exp(tmp3)
    tmp5 = tmp1 - tmp2
    tmp6 = tl_math.exp(tmp5)
    tmp7 = tmp4 + tmp6
    tmp8 = tmp4 / tmp7
    tmp9 = tmp6 / tmp7
    tmp10 = triton_helpers.maximum(tmp8, tmp9)
    tl.store(out_ptr0 + x0, tmp2, None)
    tl.store(out_ptr1 + x0, tmp10, None)
@triton.jit
def triton_poi_fused__softmax_lt_masked_fill_1(in_ptr0, in_ptr1, in_ptr2,
        out_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr):
    xnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = xindex // 2
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr0 + 2 * x1, xmask, eviction_policy='evict_last')
    tmp2 = tl.load(in_ptr0 + (1 + 2 * x1), xmask, eviction_policy='evict_last')
    tmp8 = tl.load(in_ptr1 + x1, xmask, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr2 + x1, xmask, eviction_policy='evict_last')
    tmp3 = tmp1 > tmp2
    tmp4 = tmp1 == tmp2
    tmp5 = tmp1 != tmp1
    tmp6 = tmp2 != tmp2
    tmp7 = tmp5 > tmp6
    tmp9 = tmp3 | tmp7
    tmp11 = tmp10 * tmp8
    tmp12 = tmp0 < tmp11
    # FIX: set tmp13 to -inf
    # tmp13 = -1.0000000116860974e-07
    tmp13 = -float('inf')
    tmp14 = tl.where(tmp12, tmp13, tmp0)
    tl.store(out_ptr0 + x2, tmp9, xmask)
    tl.store(out_ptr1 + x2, tmp14, xmask)
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (2, 128), (128, 1))
    assert_size_stride(arg1_1, (2,), (1,))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((2, 1), (1, 2), torch.float32)
        buf1 = empty_strided_cuda((2, 1), (1, 2), torch.float32)
        get_raw_stream(0)
        triton_poi_fused__softmax_max_0[grid(2)](arg0_1, buf0, buf1, 2,
            XBLOCK=2, num_warps=1, num_stages=1)
        buf2 = empty_strided_cuda((2, 128), (128, 1), torch.bool)
        buf3 = empty_strided_cuda((2, 128), (128, 1), torch.float32)
        triton_poi_fused__softmax_lt_masked_fill_1[grid(256)](arg0_1, buf1,
            buf0, buf2, buf3, 256, XBLOCK=128, num_warps=4, num_stages=1)
        del arg0_1
        del buf0
        del buf1
    # FIX: remove reinterpret_tensor which is undefined
    # return buf3, reinterpret_tensor(arg1_1, (2, 1), (1, 1), 0), buf2
    return buf3,
class ModelNew(nn.Module):
    def __init__(self):
        super(ModelNew, self).__init__()

    def forward(self, input_0, input_1):
        arg0_1 = input_0
        # FIX: squeeze input shape from [2, 1] to [2], which is expected by the assertion in the kernel
        # arg1_1 = input_1
        arg1_1 = input_1[:, 0]
        output = call([arg0_1, arg1_1])
        return output[0]
This code now matches the original model's forward on the test inputs:
m_orig = Model()
m_new = ModelNew()
inputs = get_inputs()
inputs_cuda = [i.cuda() for i in inputs]
out_orig = m_orig(*inputs)
out_new = m_new(*inputs_cuda)
print(out_orig == out_new.cpu())
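As a small aside (not required by KernelLLM), you can collapse the elementwise comparison into a single pass/fail flag, with torch.allclose as the tolerant variant for float outputs:

# Single pass/fail instead of an elementwise boolean tensor.
print(bool((out_orig == out_new.cpu()).all()))
# Tolerance-based check, in case a fused kernel introduces tiny float differences.
print(torch.allclose(out_orig, out_new.cpu()))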
I hope this helps!
Kind regards,
Zacharias
How can we train a model to be more flexible? I am also training a model for Triton and have the same problem: my model and training pipeline expect a functional format.
I believe the "problem" with supervised finetuning is that you burn your samples into the LLM pretty strongly. On the upside, you can really teach the model new capabilities.
In general, the model will only learn to be as flexible as the data it was trained on, so mixing in general coding-related & instruction-following data, plus good old-fashioned data curation, should help.
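As a rough sketch of what such a mix could look like (the helper below is made up for illustration and is not from the KernelLLM training setup): pair each target kernel with prompts in more than one format, and dilute with general data.

import random

def build_training_mix(module_format_pairs, functional_format_pairs,
                       general_coding_samples, seed=0):
    """Interleave several prompt formats so the fine-tune doesn't overfit to one."""
    rng = random.Random(seed)
    mix = []
    # The same target kernels, phrased both as nn.Module code and as plain
    # functions, teach the model that both formats describe the same task.
    for prompt, completion in module_format_pairs + functional_format_pairs:
        mix.append({"prompt": prompt, "completion": completion})
    # Dilute with general coding / instruction-following samples so the kernel
    # format doesn't get burned in too strongly.
    mix.extend(general_coding_samples)
    rng.shuffle(mix)
    return mix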