|
#ifndef BLOCKPARTY_CSRC_CUDA_UTIL_H_ |
|
#define BLOCKPARTY_CSRC_CUDA_UTIL_H_ |
|
|
|
#include <cuda_fp16.h> |
|
#include <cuda_runtime.h> |
|
|
|
|
|
namespace megablocks { |
|
|
|
// Local alias so code in this namespace can spell CUDA's `__half2` as `half2`.
using half2 = __half2;
|
|
|
// Two `half2` members bundled into a single aggregate that is aligned to
// 8 bytes. Members are initialized in the order x, y.
struct __align__(8) half4 {
  half2 x;
  half2 y;
};
|
|
|
// Four `half2` members bundled into a single aggregate that is aligned to
// 16 bytes. Members are initialized in the order x, y, z, w.
struct __align__(16) half8 {
  half2 x;
  half2 y;
  half2 z;
  half2 w;
};
|
|
|
// Reinterprets the object representation of `src` as a value of type `To`.
//
// Copies the first sizeof(To) bytes of `src` into a local `To` and returns
// it. `To` may be smaller than `From` (the trailing bytes of `src` are then
// ignored) but never larger: that would read past the end of `src`, so it is
// rejected at compile time.
template <class To, class From>
__device__ __forceinline__ To BitCast(const From& src) noexcept {
  static_assert(sizeof(To) <= sizeof(From),
                "BitCast: To must not be larger than From");
  To dst;
  // Unqualified memcpy on purpose: the CUDA runtime headers declare it as
  // callable from device code, whereas std::memcpy additionally requires
  // <cstring>, which this header does not include.
  memcpy(&dst, &src, sizeof(To));
  return dst;
}
|
|
|
// Writes `value` to the location that `ptr` points at.
template <typename T>
__device__ __forceinline__ void Store(const T& value, T* ptr) {
  ptr[0] = value;
}
|
|
|
// Reads one `T` from `address` using the `__ldg` intrinsic and returns it by
// value. `T` must therefore be a type that `__ldg` accepts.
template <typename T>
__device__ __forceinline__ T Load(const T* address) {
  T fetched = __ldg(address);
  return fetched;
}
|
|
|
// `half4` overload of Load: views `address` as a `float2` pointer so the
// whole value is fetched by one `__ldg` call, then converts the loaded bits
// back into a `half4`.
__device__ __forceinline__ half4 Load(const half4* address) {
  const float2* wide_address = reinterpret_cast<const float2*>(address);
  const float2 packed = __ldg(wide_address);
  return BitCast<half4>(packed);
}
|
|
|
// `half8` overload of Load: views `address` as a `float4` pointer so the
// whole value is fetched by one `__ldg` call, then converts the loaded bits
// back into a `half8`.
__device__ __forceinline__ half8 Load(const half8* address) {
  const float4* wide_address = reinterpret_cast<const float4*>(address);
  const float4 packed = __ldg(wide_address);
  return BitCast<half8>(packed);
}
|
|
|
// Returns the value obtained by converting the literal 0 to `T`.
// Types that cannot be produced this way provide their own specialization.
template <typename T>
__device__ __forceinline__ T Zero() {
  return 0;
}
|
|
|
// Zero for `half2`: both lanes hold 0.
//
// Built with the `__float2half2_rn` intrinsic from <cuda_fp16.h>, which
// broadcasts one float into both lanes. This avoids going through
// `c10::Half`, a type this header never includes, and avoids the
// double-precision literals the previous form used.
template <>
__device__ __forceinline__ half2 Zero<half2>() {
  return __float2half2_rn(0.0f);
}
|
|
|
// Zero for `half4`: both `half2` members are set to `Zero<half2>()`.
template <>
__device__ __forceinline__ half4 Zero<half4>() {
  const half2 zero_pair = Zero<half2>();
  return {zero_pair, zero_pair};
}
|
|
|
} |
|
|
|
#endif |
|
|