#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization.h"
#include "quantization_utils.h"
#include "reduction_utils.h"
|
namespace cg = cooperative_groups;
|
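/*
Pure quantization kernel: each thread caches its slice of a group of `elems_per_group`
__half values in registers, then quantize::local_array reduces the group, writes the
per-group quantization parameters to `params`, and emits `q_bits`-wide (4- or 8-bit)
symmetric or asymmetric values to `output_data`.
*/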
template <int q_bits,
          quantize::Type quant_type,
          int UNROLL,
          int internal_unroll,
          int threads_per_group,
          int max_threads>
__global__ void cached_quantization(int8_t* __restrict__ output_data,
                                    float* __restrict__ params,
                                    const __half* __restrict__ input_data,
                                    int groups,
                                    int elems_per_group)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
|
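    // Indexing: blockIdx.x selects a chunk of (max_threads / threads_per_group) groups,
    // threadIdx.y picks the group within that chunk, and threadIdx.x walks the group in
    // vectors of quantize::h_per_load __half elements.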
    const int block_offset =
        (tb.group_index().x * (max_threads / threads_per_group) * elems_per_group) +
        (tb.thread_index().y * elems_per_group);
    const int elem_offset = tb.thread_index().x * quantize::h_per_load;
    const int base_offset = block_offset + elem_offset;
    const int stride = tb.size() * quantize::h_per_load;

    const __half* input_base = input_data + base_offset;
|
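    // Register cache sized for every vector this thread can load across the external
    // (UNROLL) and internal unroll steps; loads past the end of the group are
    // predicated off below.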
    __half2 local_buffer[UNROLL * internal_unroll * quantize::h2_per_load];

#pragma unroll
    for (int i = 0; i < UNROLL; i++) {
        // Pointer into the register cache for this external unroll step; with full
        // unrolling this should resolve to register indices rather than a live pointer.
        __half2* iteration_buffer = local_buffer + i * internal_unroll * quantize::h2_per_load;
#pragma unroll
        for (int j = 0; j < internal_unroll; j++) {
            const int iteration = i * internal_unroll + j;
            mem_access::load_global<quantize::granularity>(
                iteration_buffer + j * quantize::h2_per_load,
                input_base + iteration * stride,
                elem_offset + iteration * stride < elems_per_group);
        }
    }
|
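    // Quantize the register-cached data: per-group parameters go to `params`, quantized
    // values to `output_data`.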
    quantize::
        local_array<quant_type, q_bits, UNROLL * internal_unroll, threads_per_group, max_threads>(
            local_buffer, params, output_data, elems_per_group, groups);
}
|
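// Dispatch helpers: LAUNCH_CACHED_QUANT_CALL instantiates cached_quantization for a
// compile-time bit width and quantization type; LAUNCH_CACHED_QUANT maps the runtime
// (num_bits, quant_type) pair onto one of those instantiations. Both expand inside
// launch_quant and rely on the locals defined there.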
#define LAUNCH_CACHED_QUANT_CALL(q_bits, quant_type) \
    cached_quantization<q_bits,                      \
                        quant_type,                  \
                        unroll_factor,               \
                        internal_unroll_l,           \
                        threads_per_group,           \
                        max_threads>                 \
        <<<grid, block, 0, stream>>>(output_data, params, input_data, groups, elems_per_group);
|
#define LAUNCH_CACHED_QUANT(                                                            \
    q_bits, quant_type, unroll_factor_in, internal_unroll_in, threads_per_group_in)     \
    const int unroll_factor = unroll_factor_in;                                         \
    const int internal_unroll_l = internal_unroll_in;                                   \
    const int threads_per_group = threads_per_group_in;                                 \
    if (q_bits == 4) {                                                                  \
        if (quant_type == quantize::Type::Asymmetric) {                                 \
            LAUNCH_CACHED_QUANT_CALL(4, quantize::Type::Asymmetric)                     \
        } else {                                                                        \
            LAUNCH_CACHED_QUANT_CALL(4, quantize::Type::Symmetric)                      \
        }                                                                               \
    } else {                                                                            \
        if (quant_type == quantize::Type::Asymmetric) {                                 \
            LAUNCH_CACHED_QUANT_CALL(8, quantize::Type::Asymmetric)                     \
        } else {                                                                        \
            LAUNCH_CACHED_QUANT_CALL(8, quantize::Type::Symmetric)                      \
        }                                                                               \
    }
|
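/*
Host-side launcher for cached_quantization. Groups of up to 128 __half elements are
packed several to a block (sub-block schedule); larger groups each get their own block
of up to max_threads threads, with the external unroll factor chosen so the register
cache covers the whole group.
*/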
void launch_quant(int8_t* output_data,
                  float* params,
                  const __half* input_data,
                  const int groups,
                  const int elems_per_group,
                  const int num_bits,
                  const quantize::Type quant_type,
                  cudaStream_t stream)
{
    constexpr int max_threads = 256;
    constexpr int internal_unroll = 2;

    // Groups of at most 128 halves are small enough to pack several per block; larger
    // groups get a block to themselves and load internal_unroll vectors per thread per step.
    const bool is_subblock_schedule = (elems_per_group <= 128);
    const int h_per_step = is_subblock_schedule ? quantize::h_per_load
                                                : quantize::h_per_load * internal_unroll;

    // Round the thread count needed to cover one group in a single step up to a power
    // of two, capped at max_threads.
    const int one_step_threads = next_pow2((elems_per_group + h_per_step - 1) / h_per_step);
    const int threads_per_group = (one_step_threads < max_threads) ? one_step_threads : max_threads;

    const int groups_per_block =
        is_subblock_schedule ? (max_threads + threads_per_group - 1) / threads_per_group : 1;
    const int groups_launch = (groups + groups_per_block - 1) / groups_per_block;

    dim3 block(threads_per_group, groups_per_block);
    dim3 grid(groups_launch);

    const int elems_per_step = threads_per_group * h_per_step;
    const int external_unroll = (elems_per_group + elems_per_step - 1) / elems_per_step;
|
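    // Dispatch on the schedule. Sub-block schedules always use a single load per thread;
    // block schedules select the external unroll factor (1-4) that covers elems_per_group.
    // Groups too large for external_unroll == 4 fall through without launching a kernel.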
    if (is_subblock_schedule) {
        if (threads_per_group == 1) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 1);
        } else if (threads_per_group == 2) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 2);
        } else if (threads_per_group == 4) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 4);
        } else if (threads_per_group == 8) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 8);
        } else if (threads_per_group == 16) {
            LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 16);
        }
    } else if (external_unroll == 1) {
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, internal_unroll, max_threads);
    } else if (external_unroll == 2) {
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 2, internal_unroll, max_threads);
    } else if (external_unroll == 3) {
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 3, internal_unroll, max_threads);
    } else if (external_unroll == 4) {
        LAUNCH_CACHED_QUANT(num_bits, quant_type, 4, internal_unroll, max_threads);
    }
}
|
|
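/*
Example host-side call (a sketch, not part of this file's API surface): `d_input`,
`d_output`, and `d_params` are hypothetical device allocations sized for
groups * elems_per_group input halves, the matching int8 output, and the per-group
parameter layout expected by the quantization utilities, respectively.

    launch_quant(d_output, d_params, d_input, groups, elems_per_group,
                 8, quantize::Type::Symmetric, stream);
*/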