#pragma once #include #include #include namespace at { enum class ROCmFABackend : int8_t { Default, AOTriton, Ck }; inline std::string ROCmFABackendToString(at::ROCmFABackend backend) { switch (backend) { case ROCmFABackend::Default: return "at::ROCmFABackend::Default"; case ROCmFABackend::AOTriton: return "at::ROCmFABackend::AOTriton"; case ROCmFABackend::Ck: return "at::ROCmFABackend::Ck"; default: TORCH_CHECK(false, "Unknown ROCm flash attention backend") } } inline std::ostream& operator<<( std::ostream& stream, at::ROCmFABackend backend) { return stream << ROCmFABackendToString(backend); } } // namespace at