|
#pragma once |
|
|
|
#include "cutlass/cutlass.h" |
|
#include <climits> |
|
#include "cuda_runtime.h" |
|
#include <iostream> |
|
|
|
|
|
|
|
|
|
/**
 * Checks the result of a CUTLASS call.
 *
 * Evaluates `status` exactly once, compares the result against
 * cutlass::Status::kSuccess and hands that condition to TORCH_CHECK, with
 * cutlassGetStatusString() of the result as the message argument.
 *
 * The expansion is a single `do { ... } while (0)` statement: the call site
 * supplies the trailing semicolon, and the macro can be used as the body of
 * an unbraced if/else. The temporary carries a macro-specific name so that a
 * `status` expression mentioning a variable called `error` is not shadowed.
 *
 * NOTE(review): no include visible in this header declares TORCH_CHECK;
 * presumably the including translation unit provides it -- confirm.
 */
#define CUTLASS_CHECK(status)                                       \
  do {                                                              \
    cutlass::Status const cutlass_check_status_ = (status);         \
    TORCH_CHECK(cutlass_check_status_ == cutlass::Status::kSuccess, \
                cutlassGetStatusString(cutlass_check_status_));     \
  } while (0)
|
|
|
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { |
|
int max_shared_mem_per_block_opt_in = 0; |
|
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, |
|
cudaDevAttrMaxSharedMemoryPerBlockOptin, device); |
|
return max_shared_mem_per_block_opt_in; |
|
} |
|
|
|
// Declaration only: the definition is not part of this header, so it has to
// be supplied by another translation unit at link time.
// NOTE(review): going by the name, this presumably reports the SM (compute
// capability) version of a CUDA device as a single integer; the exact
// encoding and which device is queried are not visible here -- confirm
// against the definition.
int32_t get_sm_version_num(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Kernel> |
|
struct enable_sm90_or_later : Kernel { |
|
template <typename... Args> |
|
CUTLASS_DEVICE void operator()(Args&&... args) { |
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 |
|
Kernel::operator()(std::forward<Args>(args)...); |
|
#endif |
|
} |
|
}; |
|
|
|
template <typename Kernel> |
|
struct enable_sm90_only : Kernel { |
|
template <typename... Args> |
|
CUTLASS_DEVICE void operator()(Args&&... args) { |
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900 |
|
Kernel::operator()(std::forward<Args>(args)...); |
|
#endif |
|
} |
|
}; |
|
|
|
template <typename Kernel> |
|
struct enable_sm100_only : Kernel { |
|
template <typename... Args> |
|
CUTLASS_DEVICE void operator()(Args&&... args) { |
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000 |
|
Kernel::operator()(std::forward<Args>(args)...); |
|
#endif |
|
} |
|
}; |
|
|