From e81d7f11dede2b9b3f82de00a433eccc3d47c25e Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 30 Jan 2025 23:49:14 +0800
Subject: [PATCH] add tensorrt_llm moe_gemm as 3rdparty (#3217)

---
 .../tensorrt_llm/common/cudaBf16Wrapper.h     |  21 +
 .../tensorrt_llm/common/cudaDriverWrapper.cpp | 187 ----
 .../tensorrt_llm/common/cudaDriverWrapper.h   | 138 ---
 .../launchers/fused_moe_gemm_launcher_sm80.h  |  25 +
 .../fused_moe_gemm_launcher_sm80.inl          |  96 ++
 .../launchers/moe_gemm_launcher_sm90.h        |  37 +
 .../launchers/moe_gemm_launcher_sm90.inl      | 348 ++++++++
 .../moe_gemm/moe_gemm_hopper_input.cu         | 131 +++
 .../moe_gemm/moe_gemm_kernels.h               | 230 +++++
 .../moe_gemm/moe_gemm_kernels_bf16_bf16.cu    |  24 +
 .../moe_gemm/moe_gemm_kernels_bf16_uint4.cu   |  24 +
 .../moe_gemm/moe_gemm_kernels_bf16_uint8.cu   |  24 +
 .../moe_gemm/moe_gemm_kernels_fp16_fp16.cu    |  22 +
 .../moe_gemm/moe_gemm_kernels_fp16_uint4.cu   |  22 +
 .../moe_gemm/moe_gemm_kernels_fp16_uint8.cu   |  22 +
 .../moe_gemm/moe_gemm_kernels_fp32_fp32.cu    |  22 +
 .../moe_gemm/moe_gemm_kernels_fp8_fp8.cu      |  28 +
 .../moe_gemm/moe_gemm_kernels_template.h      | 823 ++++++++++++++
 .../moe_gemm/moe_gemm_kernels_template_sm90.h | 222 +++++
 .../moe_gemm/moe_sm90_traits.h                |  44 +
 20 files changed, 2165 insertions(+), 325 deletions(-)
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h
 delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp
 delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h
 create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h
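[Editor's note] For orientation: the kernels vendored below implement a grouped GEMM over experts. Tokens are pre-sorted by expert, and `total_tokens_including_expert[i]` holds the inclusive prefix sum of rows assigned to each expert, so expert `i` multiplies rows `[total[i-1], total[i])` of the activation matrix by its own weight matrix. A minimal reference sketch of those semantics follows; the function name and the plain triple loop are illustrative, not part of the patch:

```cpp
#include <cstdint>
#include <vector>

// Reference semantics of the grouped MoE GEMM. A is row-major [num_rows, k],
// B holds one [k, n] weight matrix per expert, C is [num_rows, n].
void moe_grouped_gemm_reference(std::vector<float> const& A, std::vector<float> const& B,
    std::vector<float>& C, std::vector<int64_t> const& total_tokens_including_expert,
    int64_t n, int64_t k)
{
    int64_t row_begin = 0;
    for (size_t expert = 0; expert < total_tokens_including_expert.size(); ++expert)
    {
        int64_t const row_end = total_tokens_including_expert[expert]; // inclusive prefix sum
        for (int64_t row = row_begin; row < row_end; ++row)
        {
            for (int64_t col = 0; col < n; ++col)
            {
                float acc = 0.f;
                for (int64_t i = 0; i < k; ++i)
                    acc += A[row * k + i] * B[(expert * k + i) * n + col];
                C[row * n + col] = acc;
            }
        }
        row_begin = row_end; // next expert starts where this one ended
    }
}
```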
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h
new file mode 100644
index 000000000..fb2a89af5
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#ifdef ENABLE_BF16
+#include <cuda_bf16.h>
+#endif
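[Editor's note] The wrapper above lets every translation unit include one header unconditionally and key any bfloat16-specific code off the same `ENABLE_BF16` macro, so builds on toolchains without bf16 support still compile. A hedged usage sketch; the `to_float` helper is illustrative, not part of the patch:

```cpp
#include "tensorrt_llm/common/cudaBf16Wrapper.h"

// Illustrative only: bf16 code paths compile only when the build enables them.
#ifdef ENABLE_BF16
__device__ float to_float(__nv_bfloat16 x)
{
    return __bfloat162float(x); // conversion intrinsic from cuda_bf16.h
}
#endif
```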
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp
deleted file mode 100644
index 7eca46a1c..000000000
--- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#define CUDA_LIB_NAME "cuda"
-
-#if defined(_WIN32)
-#include <windows.h>
-#define dllOpen(name) LoadLibrary("nv" name ".dll")
-#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
-#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
-#else // For non-Windows platforms
-#include <dlfcn.h>
-#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
-#define dllClose(handle) dlclose(handle)
-#define dllGetSym(handle, name) dlsym(handle, name)
-#endif // defined(_WIN32)
-
-#include "cudaDriverWrapper.h"
-#include "tensorrt_llm/common/assert.h"
-#include <cuda.h>
-#include <mutex>
-
-namespace tensorrt_llm::common
-{
-
-std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
-{
-    static std::mutex mutex;
-    static std::weak_ptr<CUDADriverWrapper> instance;
-    std::shared_ptr<CUDADriverWrapper> result = instance.lock();
-    if (result)
-    {
-        return result;
-    }
-
-    std::lock_guard<std::mutex> lock(mutex);
-    result = instance.lock();
-    if (!result)
-    {
-        result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
-        instance = result;
-    }
-    return result;
-}
-
-CUDADriverWrapper::CUDADriverWrapper()
-    : handle(dllOpen(CUDA_LIB_NAME))
-{
-
-    TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
-
-    auto load_sym = [](void* handle, char const* name)
-    {
-        void* ret = dllGetSym(handle, name);
-        return ret;
-    };
-
-    *reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
-    *reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
-    *reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
-    *reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
-    *reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
-    *reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
-    *reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
-    *reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
-    *reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
-    *reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
-    *reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
-    *reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
-    *reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
-    *reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
-    *reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
-    *reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
-}
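[Editor's note] The (now-removed) wrapper above binds CUDA driver entry points at runtime instead of linking against libcuda at build time: each member is a raw function pointer filled from `dlsym`/`GetProcAddress`, and versioned symbols (`_v2` suffixes) are resolved explicitly. A minimal self-contained sketch of the same technique, Linux-only and independent of this patch:

```cpp
#include <dlfcn.h>
#include <cstdio>

// cuDriverGetVersion(int*) returns a CUresult (0 on success); declaring the
// pointer type locally avoids any compile-time dependency on cuda.h.
using cuDriverGetVersion_t = int (*)(int*);

int main()
{
    // RTLD_LAZY defers symbol resolution until first use, as in the code above.
    void* lib = dlopen("libcuda.so.1", RTLD_LAZY);
    if (!lib)
    {
        std::printf("driver not available: %s\n", dlerror());
        return 1;
    }
    auto get_version = reinterpret_cast<cuDriverGetVersion_t>(dlsym(lib, "cuDriverGetVersion"));
    int version = 0;
    if (get_version && get_version(&version) == 0)
        std::printf("CUDA driver version %d\n", version);
    dlclose(lib);
    return 0;
}
```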
-
-CUDADriverWrapper::~CUDADriverWrapper()
-{
-    dllClose(handle);
-}
-
-CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
-{
-    return (*_cuGetErrorName)(error, pStr);
-}
-
-CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
-{
-    return (*_cuGetErrorMessage)(error, pStr);
-}
-
-CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
-{
-    return (*_cuFuncSetAttribute)(hfunc, attrib, value);
-}
-
-CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
-{
-    return (*_cuLinkComplete)(state, cubinOut, sizeOut);
-}
-
-CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
-{
-    return (*_cuModuleUnload)(hmod);
-}
-
-CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
-{
-    return (*_cuLinkDestroy)(state);
-}
-
-CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
-{
-    return (*_cuModuleLoadData)(module, image);
-}
-
-CUresult CUDADriverWrapper::cuLinkCreate(
-    unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
-{
-    return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
-}
-
-CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
-{
-    return (*_cuModuleGetFunction)(hfunc, hmod, name);
-}
-
-CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
-{
-    return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
-}
-
-CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
-    unsigned int numOptions, CUjit_option* options, void** optionValues) const
-{
-    return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
-}
-
-CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
-    char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
-{
-    return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
-}
-
-CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
-    unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
-    unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
-{
-    return (*_cuLaunchCooperativeKernel)(
-        f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
-}
-
-CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
-    unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
-    unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
-{
-    return (*_cuLaunchKernel)(
-        f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
-}
-
-CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
-    cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
-    cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
-    CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
-{
-    return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
-        boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
-}
-
-CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
-{
-    return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
-}
-
-} // namespace tensorrt_llm::common
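[Editor's note] Every wrapper method simply forwards through the stored pointer, so call sites keep the driver-API signatures unchanged while the binary can still load on machines without libcuda; the companion header (next diff) wraps results in `TLLM_CU_CHECK`. A hedged sketch of how the removed wrapper was typically consumed; `loadModuleOrNull` and its error handling are illustrative, not code from this patch:

```cpp
#include "tensorrt_llm/common/cudaDriverWrapper.h" // the header deleted below
#include <cuda.h>
#include <cstdio>

// Illustrative: load a cubin image through the wrapper and decode failures
// with cuGetErrorName instead of aborting via TLLM_CU_CHECK.
CUmodule loadModuleOrNull(void const* image_bytes)
{
    auto driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
    CUmodule module = nullptr;
    if (CUresult status = driver->cuModuleLoadData(&module, image_bytes); status != CUDA_SUCCESS)
    {
        char const* name = nullptr;
        driver->cuGetErrorName(status, &name);
        std::fprintf(stderr, "cuModuleLoadData failed: %s\n", name ? name : "unknown");
        return nullptr;
    }
    return module;
}
```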
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h
deleted file mode 100644
index c4d470a85..000000000
--- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef CUDA_DRIVER_WRAPPER_H
-#define CUDA_DRIVER_WRAPPER_H
-
-#include "tensorrt_llm/common/assert.h"
-#include <cstdio>
-#include <cuda.h>
-#include <memory>
-#include <mutex>
-
-namespace tensorrt_llm::common
-{
-
-class CUDADriverWrapper
-{
-public:
-    static std::shared_ptr<CUDADriverWrapper> getInstance();
-
-    ~CUDADriverWrapper();
-    CUDADriverWrapper(CUDADriverWrapper const&) = delete;
-    CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
-    CUDADriverWrapper(CUDADriverWrapper&&) = delete;
-    CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
-
-    CUresult cuGetErrorName(CUresult error, char const** pStr) const;
-
-    CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
-
-    CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
-
-    CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
-
-    CUresult cuModuleUnload(CUmodule hmod) const;
-
-    CUresult cuLinkDestroy(CUlinkState state) const;
-
-    CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
-
-    CUresult cuLinkCreate(
-        unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
-
-    CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
-
-    CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
-
-    CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
-        CUjit_option* options, void** optionValues) const;
-
-    CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
-        unsigned int numOptions, CUjit_option* options, void** optionValues) const;
-
-    CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
-        unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
-        unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
-
-    CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
-        unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
-        CUstream hStream, void** kernelParams, void** extra) const;
-
-    CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
-        void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
-        cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
-        CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
-
-    CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
-
-private:
-    void* handle;
-    CUDADriverWrapper();
-
-    CUresult (*_cuGetErrorName)(CUresult, char const**);
-    CUresult (*_cuGetErrorMessage)(CUresult, char const**);
-    CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
-    CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
-    CUresult (*_cuModuleUnload)(CUmodule);
-    CUresult (*_cuLinkDestroy)(CUlinkState);
-    CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
-    CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
-    CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
-    CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
-    CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
-    CUresult (*_cuLinkAddData)(
-        CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
-    CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
-        unsigned int, unsigned int, unsigned int, CUstream, void**);
-    CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
-        unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
-        CUstream hStream, void** kernelParams, void** extra);
-    CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
-        cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
-        cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
-        CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
-    CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
-};
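[Editor's note] `getInstance()` (in the .cpp above) keeps a static `std::weak_ptr`, so the wrapper is shared while any user holds it but can be torn down and reloaded after the last reference drops; a mutex plus a second `lock()` avoids racing constructions. The same pattern in isolation, with `Widget` as a stand-in name:

```cpp
#include <memory>
#include <mutex>

struct Widget
{
    static std::shared_ptr<Widget> getInstance()
    {
        static std::mutex mutex;
        static std::weak_ptr<Widget> instance; // does not keep the object alive
        if (auto alive = instance.lock())      // fast path, no lock taken
            return alive;
        std::lock_guard<std::mutex> lock(mutex);
        if (auto alive = instance.lock())      // re-check after taking the lock
            return alive;
        auto fresh = std::shared_ptr<Widget>(new Widget());
        instance = fresh;
        return fresh;
    }

private:
    Widget() = default; // construction only through getInstance()
};
```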
-
-template <typename T>
-void checkDriver(
-    T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
-{
-    if (result)
-    {
-        char const* errorName = nullptr;
-        char const* errorMsg = nullptr;
-        wrap.cuGetErrorName(result, &errorName);
-        wrap.cuGetErrorMessage(result, &errorMsg);
-        throw TllmException(
-            file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
-    }
-}
-
-} // namespace tensorrt_llm::common
-
-/*
- * Macros compliant with TensorRT coding conventions
- */
-#define TLLM_CU_CHECK(stat)                                                                                            \
-    do                                                                                                                 \
-    {                                                                                                                  \
-        tensorrt_llm::common::checkDriver(                                                                             \
-            (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__);               \
-    } while (0)
-
-#endif // CUDA_DRIVER_WRAPPER_H
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h
new file mode 100644
index 000000000..f4eed277c
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template +void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, + ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert, + int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, + int* kernel_occupancy); +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl new file mode 100644 index 000000000..126e761ec --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" + +#include +#include +#include + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template +void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, + ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert, + int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, + int* kernel_occupancy) +{ + constexpr auto activation_type = fused_moe::EpilogueRouting(true); + using GemmType = fused_moe::Fused_Moe_Kernel_sm80; + + // make sure GPU has enough resources.. + if (kernel_occupancy != nullptr) + { + constexpr int smem_size = GemmType::kSmemSize; + + if (smem_size > (48 << 10)) + { + cudaFuncAttributes attr{}; + int device = 0; + int max_smem_per_block = 0; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global)); + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) + { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + // smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the + // heuristic to ignore this configuration. 
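// [Editor's note] Context for the check above: CUDA kernels may use at most
// 48 KiB of dynamic shared memory by default; larger footprints require an
// explicit opt-in via cudaFuncSetAttribute(kernel,
// cudaFuncAttributeMaxDynamicSharedMemorySize, bytes), and even then the sum
// of static and dynamic shared memory must fit under the device's
// cudaDevAttrMaxSharedMemoryPerBlockOptin limit. Reporting an occupancy of 0
// lets the config heuristic silently skip tile shapes that cannot be
// satisfied on the current GPU. The feasibility test reduces to (illustrative):
//
//   int max_smem = 0;
//   cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
//   bool feasible = smem_size + attr.sharedSizeBytes < (size_t) max_smem;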
+ *kernel_occupancy = 0; + return; + } + } + + int max_active_blocks = -1; + tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, fused_moe::run_global, GemmType::kThreadCount, smem_size)); + *kernel_occupancy = max_active_blocks; + return; + } + int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks()); + int const threadblock_count = multi_processor_count * occupancy; + TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel"); + using Arguments = typename GemmType::Arguments; + Arguments args{{const_cast(A), const_cast(B), const_cast(biases), + reinterpret_cast(C), total_tokens_including_expert, static_cast(gemm_n), + static_cast(gemm_k), num_experts, bias_is_broadcast}, + num_experts, threadblock_count}; + auto params = GemmType::to_underlying_arguments(args); + if (GemmType::kSmemSize >= (48 << 10)) + { + cudaError_t result = cudaFuncSetAttribute( + fused_moe::run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize); + TLLM_CHECK_WITH_INFO(result == cudaSuccess, + "Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel"); + } + dim3 grid(params.threadblock_count, 1, 1); + dim3 block(GemmType::kThreadCount); + fused_moe::run_global<<>>(params); + auto result = cudaGetLastError(); + TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result)); +} +} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h new file mode 100644 index 000000000..91527fadb --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +// Keep in sync with the signature generated by generate_kernels.py +template +void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts, + int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl new file mode 100644 index 000000000..cca60a981 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl @@ -0,0 +1,348 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/tensor_ref.h" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion; + +// Hopper helper class for defining all the cutlass helper types +template +struct 
HopperGroupedGemmInfo +{ + using Arch = cutlass::arch::Sm90; + + // TODO Update once mixed input support is added + static_assert(cutlass::platform::is_same::value, + "CUTLASS does not currently have specialised SM90 support for quantized operations"); + +#ifdef ENABLE_FP8 + constexpr static bool IsFP8 + = cutlass::platform::is_same::value || cutlass::platform::is_same::value; +#else + constexpr static bool IsFP8 = false; +#endif + +#ifdef ENABLE_BF16 + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value || IsFP8, + "Specialized for bfloat16, half, float, fp8"); +#else + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || IsFP8, + "Specialized for half, float, fp8"); +#endif + + static_assert(cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + "Unexpected quantization type"); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType = typename TllmToCutlassTypeAdapter::type; + + using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter::type; + // For legacy reasons we convert unsigned 8-bit to signed + using CutlassWeightTypeMaybeUint8 + = std::conditional_t, cutlass::int4b_t, + CutlassWeightTypeMaybeUint4>; + using CutlassWeightType + = std::conditional_t, int8_t, CutlassWeightTypeMaybeUint8>; + + using ElementA = ElementType; + using ElementB = CutlassWeightType; + + using ElementD = typename TllmToCutlassTypeAdapter>::type; + using ElementFinalOutput = typename TllmToCutlassTypeAdapter::type; + + // using ElementC = std::conditional_t; + // using ElementCNoVoid = std::conditional_t; + using ElementC = void; + using ElementCNoVoid = ElementD; + + using ElementAccumulator = float; + + using ElementBias = ElementFinalOutput; + using ElementRouterScales = float; + + // A matrix configuration - this is transposed and swapped with B + using LayoutA = HopperGroupedGemmInput::LayoutA; + constexpr static int AlignmentA + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units + // of elements (up to 16 bytes) + + // B matrix configuration - this is transposed and swapped with A + using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand + constexpr static int AlignmentB + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units + // of elements (up to 16 bytes) + + // C matrix configuration + using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand + using StrideC = HopperGroupedGemmInput::StrideC; + // Note we use ElementType here deliberately, so we don't break when BIAS is disabled + constexpr static int AlignmentC + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units + // of elements (up to 16 bytes) + + // D matrix configuration + using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD; + using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD; + constexpr static int AlignmentD + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix + // in units of elements (up to 16 bytes) + + static_assert(cutlass::platform::is_same::value, + "Hopper Grouped GEMM specialisation doesn't support fused activation"); + + using EpilogueOp + = 
cutlass::epilogue::fusion::LinearCombination; + + // TODO Add mode for fused activation once CUTLASS adds support + // using EpilogueSchedule = cutlass::platform::conditional_t< + // cutlass::platform::is_same::value, + // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, + // cutlass::epilogue::?????????????????? /// <<<<<< what supports activations + // >; + using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; + + // Epilogue For Default Finalize + using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< // + Arch, cutlass::arch::OpClassTensorOp, // + TileShape, ClusterShape, // + cutlass::epilogue::collective::EpilogueTileAuto, // + ElementAccumulator, ElementAccumulator, // + ElementC, LayoutC*, AlignmentC, // + ElementD, LayoutD*, AlignmentD, // + EpilogueSchedule>::CollectiveOp; + + // Epilogue For Fused Finalize + using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< // + TileShape, // + ElementCNoVoid, StrideC*, // + ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, // + ElementAccumulator, // + ElementAccumulator, // + ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, // + ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales // + >::CollectiveOp; + + using CollectiveEpilogue + = std::conditional_t; + + using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>; + + using KernelSchedule + = std::conditional_t; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< // + Arch, cutlass::arch::OpClassTensorOp, // + CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here + ElementType, LayoutA*, AlignmentA, // + ElementAccumulator, // + TileShape, ClusterShape, // + StageCountAutoCarveout, KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + + using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter; +}; + +// Hopper specialised version +template +void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts, + int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size) +{ +#ifdef COMPILE_HOPPER_TMA_GEMMS + using namespace cute; + if constexpr (!should_filter_sm90_gemm_problem_shape_v) + { + using GemmInfo + = HopperGroupedGemmInfo; + + using ElementAccumulator = typename GemmInfo::ElementAccumulator; + using ElementA = typename GemmInfo::ElementA; + using ElementB = typename GemmInfo::ElementB; + using ElementC = typename GemmInfo::ElementC; + using ElementCNoVoid = typename GemmInfo::ElementCNoVoid; + using ElementD = typename GemmInfo::ElementD; + + using CollectiveMainloop = typename GemmInfo::CollectiveMainloop; + using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue; + using GemmKernel = typename GemmInfo::GemmKernel; + using GemmGrouped = typename GemmInfo::GemmGrouped; + + if (kernel_occupancy != nullptr) + { + *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = multi_processor_count; + + GemmGrouped gemm; + + if (workspace_size != nullptr) + { + // Make a mock problem shape with just the minimal information actually required to get the workspace size + // This makes some assumptions about CUTLASS's 
implementation which is suboptimal. We have a check later to + // catch future cutlass updates causing silent breakages, but that is not fool proof. + // The alternative is to wait until we have data and then dynamically allocate the workspace + typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr}; + + typename GemmGrouped::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info}; + *workspace_size = gemm.get_workspace_size(args); + return; + } + + using MainloopArguments = typename CollectiveMainloop::Arguments; + TLLM_CHECK(hopper_input.stride_a); + TLLM_CHECK(hopper_input.stride_b); + TLLM_CHECK(hopper_input.ptr_a); + TLLM_CHECK(hopper_input.ptr_b); + + MainloopArguments const mainloop_params = {reinterpret_cast(hopper_input.ptr_b), + hopper_input.stride_b, reinterpret_cast(hopper_input.ptr_a), hopper_input.stride_a}; + + typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{ + ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)}; + epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + // TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS + auto make_epi_args = [&]() + { + if constexpr (FUSION == EpilogueFusion::NONE) + { + auto epi_params = hopper_input.default_epilogue; + return EpilogueArguments{epilogue_scalars, reinterpret_cast(hopper_input.ptr_c), + hopper_input.stride_c, reinterpret_cast(epi_params.ptr_d), epi_params.stride_d}; + } + else if constexpr (FUSION == EpilogueFusion::FINALIZE) + { + // Parameters for fused finalize + auto epi_params = hopper_input.fused_finalize_epilogue; + return EpilogueArguments{ + epilogue_scalars, // Parameters to underlying epilogue + reinterpret_cast(hopper_input.ptr_c), hopper_input.stride_c, // C params + reinterpret_cast(epi_params.ptr_final_output), + epi_params.stride_final_output, // D (output) params + reinterpret_cast(epi_params.ptr_bias), + epi_params.stride_bias, // Bias params + epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales + epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales + epi_params.ptr_source_token_index, // Index of the source token to sum into + epi_params.num_rows_in_final_output // Number of tokens in the output buffer + }; + } + else + { + static_assert( + sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher"); + } + }; + EpilogueArguments const epilogue_params = make_epi_args(); + + typename GemmKernel::TileScheduler::Arguments scheduler_args{ + 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; + + typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info, + mainloop_params, epilogue_params, hw_info, scheduler_args}; + + size_t calculated_ws_size = gemm.get_workspace_size(args); + TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size, + "Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size); + + auto can_implement = gemm.can_implement(args); + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, + "Grouped GEMM kernel will fail for params. 
Error: " + std::string(cutlassGetStatusString(can_implement))); + + auto init_status = gemm.initialize(args, hopper_input.gemm_workspace); + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, + "Failed to initialize cutlass SM90 grouped gemm. Error: " + + std::string(cutlassGetStatusString(init_status))); + + auto run_status = gemm.run(stream); + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, + "Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status))); + sync_check_cuda_error(); + } + else + { + TLLM_THROW("Configuration was disabled by FAST_BUILD"); + } + +#else // COMPILE_HOPPER_TMA_GEMMS + TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py."); +#endif // COMPILE_HOPPER_TMA_GEMMS +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu new file mode 100644 index 000000000..9862460dd --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu @@ -0,0 +1,131 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/conv/convolution.h" +// Order matters here, packed_stride.hpp is missing cute and convolution includes +#include "cutlass/util/packed_stride.hpp" + +#include "tensorrt_llm/common/logger.h" + +namespace tensorrt_llm +{ +std::array HopperGroupedGemmInput::workspaceBuffers(int num_experts) +{ + size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts; + size_t stride_a_size = sizeof(StrideA) * num_experts; + size_t stride_b_size = sizeof(StrideB) * num_experts; + size_t stride_c_size = sizeof(StrideC) * num_experts; + size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts; + + size_t ptr_buf_size = sizeof(void*) * num_experts; + size_t scale_buf_size = sizeof(float*) * num_experts; + + return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size, + ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size}; +} + +size_t HopperGroupedGemmInput::workspaceSize(int num_experts) +{ + auto buffers = workspaceBuffers(num_experts); + return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size()); +} + +void HopperGroupedGemmInput::configureWorkspace( + int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size) +{ + auto buffers = workspaceBuffers(num_experts); + std::array pointers{}; + TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers"); + for (int i = 0; i < buffers.size(); i++) + { + pointers[i] = start_ptr; + start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]); + } + + shape_info.num_groups = num_experts; + shape_info.problem_shapes = reinterpret_cast(pointers[0]); + shape_info.host_problem_shapes = nullptr; + stride_a = reinterpret_cast(pointers[1]); + stride_b = reinterpret_cast(pointers[2]); + stride_c = reinterpret_cast(pointers[3]); + default_epilogue.stride_d = reinterpret_cast(pointers[4]); + + ptr_a = reinterpret_cast(pointers[5]); + ptr_b = reinterpret_cast(pointers[6]); + ptr_c = reinterpret_cast(pointers[7]); + default_epilogue.ptr_d = reinterpret_cast(pointers[8]); + + alpha_scale_ptr_array = reinterpret_cast(pointers[9]); + + this->gemm_workspace = reinterpret_cast(gemm_workspace); + this->gemm_workspace_size = gemm_workspace_size; +} + +void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales, + int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, + int num_output_tokens) +{ + fused_finalize_epilogue.ptr_final_output = final_output; + fused_finalize_epilogue.ptr_router_scales = router_scales; + fused_finalize_epilogue.ptr_bias = bias; + fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset; + fused_finalize_epilogue.ptr_source_token_index = source_token_index; + + fused_finalize_epilogue.stride_final_output + = cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{}, + transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1))); + fused_finalize_epilogue.stride_bias + = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size)); + fused_finalize_epilogue.stride_router_scales = {}; + + fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens; +} + +std::string HopperGroupedGemmInput::toString() const +{ + 
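// [Editor's note] configureWorkspace above illustrates the single-arena
// pattern used throughout these kernels: workspaceBuffers() reports one size
// per sub-buffer (problem shapes, strides, pointer arrays, alpha scales), and
// nextWorkspacePtr() walks a single device allocation, handing each
// sub-buffer an aligned slice. One allocation per grouped GEMM avoids a
// cudaMalloc per array. The slicing loop reduces to (illustrative only):
//
//   int8_t* cursor = start_ptr;
//   for (size_t size : buffers)
//   {
//       /* bind `cursor` as the next sub-buffer of `size` bytes */
//       cursor = tensorrt_llm::common::nextWorkspacePtr(cursor, size);
//   }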
std::stringstream ss; + ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n"; + if (isValid()) + { + ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n"; + ss << "Epilogue Fusion: " << (int) fusion; + if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE) + { + ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output; + ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; + ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias; + ss << " with Stride: " << fused_finalize_epilogue.stride_bias; + ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales; + ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; + ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset; + ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index; + } + else + { + ss << ", Ptr D: " << default_epilogue.ptr_d; + } + ss << '\n'; + ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n"; + } + return ss.str(); +} +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h new file mode 100644 index 000000000..0616c0636 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h @@ -0,0 +1,230 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/workspace.h" +#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" +#include +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/layout/layout.h" + +namespace tensorrt_llm +{ +template +constexpr auto transpose_stride(T const& t) +{ + return cute::prepend(cute::prepend(cute::take<2, cute::rank_v>(t), cute::get<0>(t)), cute::get<1>(t)); +} + +struct HopperGroupedGemmInput +{ + template + using TransposeStride = decltype(transpose_stride(T{})); + template + using TransposeLayoutTag = std::conditional_t, + cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; + + static_assert(std::is_same_v>); + static_assert(std::is_same_v>); + + // Layout for A and B is transposed and then swapped in the implementation + // This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM + using LayoutA = TransposeLayoutTag; // Layout type for A matrix operand + using LayoutB = TransposeLayoutTag; // Layout type for B matrix operand + using LayoutC = TransposeLayoutTag; // Layout type for C matrix operand + + using StrideA + = std::remove_pointer_t>; // Use B because they will be swapped + using StrideB + = std::remove_pointer_t>; // Use A because they will be swapped + using StrideC = std::remove_pointer_t>; + + template + constexpr static bool IsFP8_v = std::is_same_v || std::is_same_v; + + // Currently this should always just be T + template + using OutputTypeAdaptor_t = std::conditional_t, nv_bfloat16, T>; + + using ProblemShape = cutlass::gemm::GroupProblemShape>; + + ProblemShape shape_info{}; + StrideA* stride_a = nullptr; + StrideB* stride_b = nullptr; + + void const** ptr_a = nullptr; + void const** ptr_b = nullptr; + + // C is currently the same in both epilogues + StrideC* stride_c = nullptr; + void const** ptr_c = nullptr; + + struct DefaultEpilogue + { + using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand + using StrideD = std::remove_pointer_t>; + + StrideD* stride_d = nullptr; + void** ptr_d = nullptr; + }; + + struct FusedFinalizeEpilogue + { + using StrideFinalOutput = DefaultEpilogue::StrideD; + using StrideBias = TransposeStride>; + using StrideRouterScales = TransposeStride>; + + void* ptr_final_output = nullptr; + StrideFinalOutput stride_final_output{}; + + void const* ptr_bias = nullptr; + StrideBias stride_bias{}; + + float const* ptr_router_scales = nullptr; + StrideRouterScales stride_router_scales{}; + + int64_t const* ptr_expert_first_token_offset = nullptr; + int const* ptr_source_token_index = nullptr; + + size_t num_rows_in_final_output = 0; + }; + + DefaultEpilogue default_epilogue; + FusedFinalizeEpilogue fused_finalize_epilogue; + + enum class EpilogueFusion + { + NONE, + ACTIVATION, + GATED_ACTIVATION, + FINALIZE + }; + EpilogueFusion fusion = EpilogueFusion::NONE; + + float const** alpha_scale_ptr_array = nullptr; + + uint8_t* gemm_workspace = nullptr; + size_t gemm_workspace_size = 0; + + static std::array workspaceBuffers(int num_experts); + + static size_t workspaceSize(int num_experts); + + void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size); + + bool isValid() const + { + return stride_a != nullptr && ptr_a != nullptr; + } + + void setFinalizeFusionParams(void* final_output, float const* router_scales, + int64_t const* 
expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, + int num_output_tokens); + + std::string toString() const; +}; + +// Note update moe.py to match +enum class ActivationType +{ + Gelu = 0, + Relu, + Silu, + Swiglu, + Geglu, + Identity, + InvalidType +}; + +constexpr bool isGatedActivation(ActivationType activation_type) +{ + return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; +} + +template +class MoeGemmRunner +{ +public: + MoeGemmRunner(); + +#if defined(ENABLE_FP8) + static constexpr bool use_fp8 = std::is_same_v || std::is_same_v; +#else + static constexpr bool use_fp8 = false; +#endif + + void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, + ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf); + + void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C, + int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, + cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf); + + std::vector getConfigs() const; + static std::vector getConfigs(int sm); + static std::vector getHopperConfigs(int sm); + static std::vector getAmpereConfigs(int sm); + + [[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const; + [[nodiscard]] bool supportsHopperSpecialisation() const; + [[nodiscard]] bool isFusedGatedActivation( + cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const; + [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const; + + size_t getMaxWorkspaceSize(int num_experts) const; + + [[nodiscard]] int getSM() const; + +private: + template + void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, + ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array, + cudaStream_t stream, int* occupancy = nullptr); + + template + void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, + bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf); + +private: + int sm_{}; + int multi_processor_count_{}; + mutable int num_experts_ = 0; + mutable size_t gemm_workspace_size_ = 0; + size_t calcMaxWorkspaceSize(int num_experts) const; +}; + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu 
b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu new file mode 100644 index 000000000..3aa96502d --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu new file mode 100644 index 000000000..fbb527045 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu new file mode 100644 index 000000000..78f1a93a6 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
+
+namespace tensorrt_llm
+{
+#ifdef ENABLE_BF16
+template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>;
+#endif
+} // namespace tensorrt_llm
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
new file mode 100644
index 000000000..69c4b6a15
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
+
+namespace tensorrt_llm
+{
+template class MoeGemmRunner<half, half, half>;
+}
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
new file mode 100644
index 000000000..4ffa5485f
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
+
+namespace tensorrt_llm
+{
+template class MoeGemmRunner<half, cutlass::uint4b_t, half>;
+}
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
new file mode 100644
index 000000000..424b817b8
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
+
+namespace tensorrt_llm
+{
+template class MoeGemmRunner<half, uint8_t, half>;
+}
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu
new file mode 100644
index 000000000..f31702356
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
+
+namespace tensorrt_llm
+{
+template class MoeGemmRunner<float, float, float>;
+}
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu
new file mode 100644
index 000000000..c6b8fe787
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
+
+namespace tensorrt_llm
+{
+#ifdef ENABLE_FP8
+template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>;
+#ifdef ENABLE_BF16
+template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
+#endif
+// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>;
+#endif
+} // namespace tensorrt_llm
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
new file mode 100644
index 000000000..2a337e6ca
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
@@ -0,0 +1,823 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/tensor_ref.h" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#ifdef __GNUC__ // Restore GCC-specific diagnostics +#pragma GCC diagnostic pop +#endif + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" + +#include "moe_gemm_kernels_template_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" +#include + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace kernels::cutlass_kernels +{ + +// ============================= Variable batched Gemm things =========================== +template +void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr) +{ +#if defined(ENABLE_FP8) + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for fp8, bfloat16, half, float"); +#elif defined(ENABLE_BF16) + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float"); +#else + static_assert(cutlass::platform::is_same::value 
|| cutlass::platform::is_same::value, + "Specialized for half, float"); +#endif + + static_assert(cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + ""); + + static_assert(!cutlass::platform::is_same::value, + "Sm90 architecture should use specialised kernels"); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType = typename TllmToCutlassTypeAdapter::type; + using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter::type; + using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; + if (!use_fused_moe) + { + // We need separate config for each architecture since we will target different tensorcore instructions. For + // float, we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; + + typename EpilogueOp::Params epilogue_op( + ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); + +#if defined(ENABLE_FP8) + if constexpr ((std::is_same_v + || std::is_same_v) &&std::is_same_v) + { + TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array, + "weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 " + "Ada"); + epilogue_op.alpha_ptr_array = alpha_scale_ptr_array; + } +#endif + + // Finally, set up the kernel. + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) + { + *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); + int const threadblock_count = multi_processor_count * occupancy; + + int const group_size = gemm_k; + typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, + reinterpret_cast(A), reinterpret_cast(B), + reinterpret_cast(weight_scales), + reinterpret_cast(biases), bias_is_broadcast, + reinterpret_cast(C), total_tokens_including_expert, gemm_n, gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, + "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); + + auto init_status = gemm.initialize(args); + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, + "Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status))); + + auto run_status = gemm.run(stream); + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, + "Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status))); + } + else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2 + && (std::is_same_v + || std::is_same_v) ) // use fused moe gemm + // kernel.. 
(only support + // fp16 or bf16) + { + sm80_generic_fused_moe_gemm_kernelLauncher(reinterpret_cast(A), + reinterpret_cast(B), reinterpret_cast(biases), + bias_is_broadcast, reinterpret_cast(C), total_tokens_including_expert, num_rows, gemm_n, + gemm_k, num_experts, multi_processor_count, stream, kernel_occupancy); + } +} + +} // namespace kernels::cutlass_kernels + +template +static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales, GemmOutputType const* biases, + bool bias_is_broadcast, GemmOutputType* C, int64_t const* total_tokens_including_expert, int64_t num_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + int* occupancy = nullptr) +{ + + static_assert(!std::is_same_v, "Use TMA specialised functions for arch SM90"); +#if defined(ENABLE_FP8) + constexpr bool isFp8 = std::is_same_v || std::is_same_v; +#else + constexpr bool isFp8 = false; +#endif + + if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80) + && (!isFp8 || std::is_same_v) ) + { + kernels::cutlass_kernels::genericMoeGemmKernelLauncher(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + else + { + TLLM_THROW( + "Cutlass gemm. Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages); + } +} + +template +void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.stages) + { + case 2: + dispatch(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case 3: + dispatch(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case 4: + dispatch(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break; + } +} + +// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. +// This overload is only enabled when T == WeightType. 
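+// The enable_if constraint is what routes each (T, WeightType, arch) combination to
+// exactly one dispatchMoeGemmToCutlass overload below. A minimal sketch of the
+// pattern, with simplified signatures rather than the real constraint lists:
+//
+//   template <typename T, typename W,
+//       typename std::enable_if<std::is_same<T, W>::value>::type* = nullptr>
+//   void dispatchExample(); // chosen when T == W (same-type tensorop path)
+//
+//   template <typename T, typename W,
+//       typename std::enable_if<!std::is_same<T, W>::value>::type* = nullptr>
+//   void dispatchExample(); // chosen when T != W (mixed-type weight-only path)
+//
+// Exactly one constraint holds for any instantiation, so the overload set acts as a
+// compile-time switch over the supported type pairings.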
+template ::value +#if defined(ENABLE_FP8) + && !std::is_same::value && !std::is_same::value +#endif + && std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; + } +} + +// Tensorop GEMM overload +// Overload for quantize MoE GEMMs. 
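+// (For example T = half with WeightType = cutlass::uint4b_t, as instantiated in
+// moe_gemm_kernels_fp16_uint4.cu; on this path the packed integer weights are
+// dequantized in-kernel against weight_scales before the tensorop MMA.)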
We disable some warp configs here since they will not be used and we can improve +// compile time +template ::value && !std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break; + } +} + +// This overload will handle tensorop gemms. 
+// This overload is only enabled when T == WeightType and T == __nv_fp8_e4m3 or __nv_fp8_e5m2 +#if defined(ENABLE_FP8) +template ::value || std::is_same::value) + && std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 64, 128>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config 
should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; + } +} +#endif + +// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. +template ::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Unsupported config for float MoE gemm."); break; + } +} + +template +std::vector +MoeGemmRunner::getConfigs() const +{ + return getConfigs(sm_); +} + +template +std::vector MoeGemmRunner::getConfigs( + int sm) +{ + std::vector candidate_configs = getHopperConfigs(sm); + std::vector ampere_configs = getAmpereConfigs(sm); + std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs)); + + return candidate_configs; +} + +template +std::vector +MoeGemmRunner::getAmpereConfigs(int sm) +{ + using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; + static constexpr auto weight_only_flag + = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; + static constexpr auto simt_only_flag + = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; + static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; + int const max_split_k = 1; + int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; + int const enable_hopper = CutlassGemmConfig::NONE; + + auto config_type_param = static_cast( + weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); + + if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) + { + return {}; + } + + std::vector ampere_configs + = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + return ampere_configs; +} + +template +std::vector +MoeGemmRunner::getHopperConfigs(int sm) +{ + using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; + static constexpr auto weight_only_flag + = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; + static constexpr auto simt_only_flag + = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; + int const max_split_k = 1; + int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; + int const enable_hopper = CutlassGemmConfig::HOPPER; + static constexpr auto fp8_only_flag = use_fp8 ? 
CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; + auto config_type_param = static_cast( + weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); + + if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + { + return {}; + } + + std::vector hopper_configs + = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + return hopper_configs; +} + +template +bool MoeGemmRunner::isHopperSpecialised( + cutlass_extensions::CutlassGemmConfig gemm_config) const +{ + bool config_is_sm90 = gemm_config.is_sm90; + return supportsHopperSpecialisation() && config_is_sm90; +} + +template +bool MoeGemmRunner::supportsHopperSpecialisation() const +{ + return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation(); +} + +template +int MoeGemmRunner::getSM() const +{ + return this->sm_; +} + +// currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction +template +bool MoeGemmRunner::supportsFusedGatedActivation( + bool is_gated_activation, int gemm_n, int gemm_k) const +{ + constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true; + return is_gated_activation && std::is_same_v && !std::is_same_v && !use_fp8 + && (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; +} + +template +bool MoeGemmRunner::isFusedGatedActivation( + cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const +{ + return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_sm90; +} + +template +MoeGemmRunner::MoeGemmRunner() +{ + int device{-1}; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + sm_ = tensorrt_llm::common::getSMVersion(); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +template +void MoeGemmRunner::dispatchToArch(T const* A, + WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, + void* C_void, int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy) +{ + static_assert(std::is_same_v, + "Separate Scale/Bias type is not supported. This is assumed to be the gemm output type"); + + // For now we always cast this to output type. 
+ // In the future this will vary based on what fusions are applied for FP8 + auto* C = reinterpret_cast(C_void); + + TLLM_CHECK_WITH_INFO( + sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation"); + TLLM_CHECK_WITH_INFO( + sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture"); + + if (sm_ >= 75 && sm_ < 80) + { + dispatchMoeGemmToCutlass(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + else if (sm_ >= 80 && sm_ < 90) + { + if constexpr (use_fp8) + { +#if defined(ENABLE_FP8) + static_assert(!std::is_same_v && !std::is_same_v, + "FP8 GEMM Output not supported"); +#endif + + TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, + occupancy); + } + else + { + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, + occupancy); + } + } + else if (sm_ >= 90) + { + if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + { + + // We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens + // SM80 is faster. We check here to see which is selected + if (gemm_config.is_sm90) + { + TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr, + "Input biases and hopper input disagree if bias is enabled"); + TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config"); + + // Select the appropriate fusion function + auto select_function = [&]() + { + switch (hopper_input.fusion) + { + case HopperGroupedGemmInput::EpilogueFusion::FINALIZE: + return &dispatchMoeGemmSelectTileShapeSM90; + case HopperGroupedGemmInput::EpilogueFusion::NONE: + return &dispatchMoeGemmSelectTileShapeSM90; + case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION: + case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION: + default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion); + }; + }; + auto selected_func = select_function(); + selected_func( + hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr); + return; + } + + // Fallthrough to SM80 impl below + } + + // Do Ampere case instead + if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) + { + TLLM_CHECK_WITH_INFO(!hopper_input.isValid(), + "Non-specialised Hopper implementation is being rerouted to fallback implementation so input " + "information is not required"); + TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90, + "GEMM config is for SM90 configuration, but this configuration is not valid for Hppper"); + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, + occupancy); + } + else + { + TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 
kernels"); + } + } + else + { + TLLM_THROW("Arch unsupported for MoE GEMM"); + } +} + +template +size_t MoeGemmRunner::getMaxWorkspaceSize(int num_experts) const +{ + if (num_experts != num_experts_) + { + TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_); + num_experts_ = num_experts; + gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts); + } + return gemm_workspace_size_; +} + +template +size_t MoeGemmRunner::calcMaxWorkspaceSize(int num_experts) const +{ + if (!supportsHopperSpecialisation()) + { + return 0; + } + if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + { + auto configs = getHopperConfigs(sm_); + size_t max_size = 0; + bool has_config = false; + for (auto conf : configs) + { +#define CALC_SIZE_FUSION(FUSION) \ + do \ + { \ + try \ + { \ + size_t size = calcMaxWorkspaceSizeSM90( \ + num_experts, conf, multi_processor_count_); \ + max_size = std::max(max_size, size); \ + has_config = true; \ + } \ + catch (tensorrt_llm::common::TllmException const& e) \ + { \ + TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size"); \ + } \ + } while (0) + + CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::NONE); + CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::FINALIZE); + } + TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size"); + return max_size; + } + else + { + TLLM_THROW("Attempting to calculate Hopper GEMM workspace size with unsupported weight combination"); + return 0; + } +} + +template +template +void MoeGemmRunner::runGemm(T const* A, WeightType const* B, + ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C, + int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, + cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) +{ + dispatchToArch(A, B, weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, + hopper_input, total_rows, gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array, + stream, nullptr); +} + +template +void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* B, + ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C, + int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) +{ + switch (activation_type) + { + case ActivationType::Relu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Gelu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Silu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); 
+ break; + case ActivationType::Identity: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Swiglu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Geglu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; + default: TLLM_THROW("Invalid activation type."); break; + } +} + +template +void MoeGemmRunner::moeGemm(T const* A, WeightType const* B, + ScaleBiasType const* weight_scales, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf) +{ + runGemm(A, B, weight_scales, nullptr, true, C, total_tokens_including_expert, + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, + chosen_conf); +} + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h new file mode 100644 index 000000000..3efb42f41 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h @@ -0,0 +1,222 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+// Ignore CUTLASS warnings about type punning
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#endif // __GNUC__
+
+#include "cutlass/array.h"
+#include "cutlass/numeric_conversion.h"
+
+#include "cutlass/gemm/device/gemm_grouped.h"
+#include "cutlass/gemm/kernel/default_gemm_grouped.h"
+
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/group_array_problem_shape.hpp"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/tensor_ref.h"
+
+#include "cutlass_extensions/compute_occupancy.h"
+#include "cutlass_extensions/epilogue_helpers.h"
+#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
+#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
+#include "cutlass_extensions/gemm/threadblock/default_mma.h"
+
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
+#pragma GCC diagnostic pop
+#endif // __GNUC__
+
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
+
+#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
+
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
+
+#include
+#include
+#include
+#include
+
+namespace tensorrt_llm
+{
+using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
+
+template
+void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
+    cudaStream_t stream, int* occupancy, size_t* workspace_size)
+{
+    static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation(),
+        "Invalid hopper configuration invoked, fallback to Sm80");
+
+    TLLM_CHECK_WITH_INFO(
+        workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information");
+
+    // auto func = hopper_input.ptr_c ?
+    // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper
+    // :
+    // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper;
+    // TODO(dastokes) Re-enable bias when CUTLASS supports it
+    auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher;
+    func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
+}
+
+/*
+  1x1x1 cluster shape is supported for any tile shape.
+
+  2x1x1 cluster shape is only supported when the M tile is at least 128.
+
+  1x2x1 cluster shape is only supported when the N tile is at least 128.
+
+  2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
+
+  We make the above restrictions to improve compilation speed in TRT-LLM by pruning kernels
+  that may not be very useful in practice.
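+
+  For example, under these rules the 128x16, 128x32 and 128x64 tiles admit only the
+  1x1x1 and 2x1x1 cluster shapes, while the 128x128, 128x256 and 256x128 tiles admit
+  all four shapes.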
+ */ +template +constexpr bool are_tile_shapes_supported() +{ + using namespace cute; + [[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{}); + [[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{}); + constexpr int cga_m = get<0>(ClusterShape{}); + constexpr int cga_n = get<1>(ClusterShape{}); + + if constexpr (cga_m == _1{} && cga_n == _1{}) + { + return true; + } + else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{}) + { + return true; + } + else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{}) + { + return true; + } + else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{}) + { + return true; + } + else + { + return false; + } +} + +template +void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, + size_t* workspace_size) +{ + using namespace cute; + switch (gemm_config.cluster_shape) + { +#define SHAPE_CASE(M, N, K) \ + case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \ + { \ + using ClusterShape = Shape<_##M, _##N, _##K>; \ + if constexpr (are_tile_shapes_supported()) \ + { \ + dispatchMoeGemmSelectBiasSM90( \ + hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \ + break; \ + } \ + else \ + { \ + TLLM_THROW("Unsupported tile and cluster shape combination"); \ + } \ + } + + SHAPE_CASE(1, 1, 1) + SHAPE_CASE(1, 2, 1) + + SHAPE_CASE(2, 1, 1) + SHAPE_CASE(2, 2, 1) + +#undef SHAPE_CASE + default: TLLM_THROW("Unsupported config for MoE gemm."); + } +} // namespace tensorrt_llm + +template +void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, + size_t* workspace_size) +{ + using namespace cute; + + switch (gemm_config.tile_config_sm90) + { +#define SHAPE_CASE(M, N, K) \ + case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \ + { \ + constexpr int KtileBytes = K / sizeof(T); \ + using KTileDim = Int; \ + using TileShape = Shape<_##M, _##N, KTileDim>; \ + dispatchMoeGemmSelectClusterShapeSM90( \ + hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \ + break; \ + } + + SHAPE_CASE(128, 16, 128) + SHAPE_CASE(128, 32, 128) + SHAPE_CASE(128, 64, 128) + SHAPE_CASE(128, 128, 128) + SHAPE_CASE(128, 256, 128) + SHAPE_CASE(256, 128, 128) + +#undef SHAPE_CASE + case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Unsupported config for MoE gemm."); break; + } +} + +template +size_t calcMaxWorkspaceSizeSM90( + int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count) +{ + size_t count; + // Most of the values are ignored for WS size calculation. 
We reuse the function to reduce the template bloat
+    dispatchMoeGemmSelectTileShapeSM90(
+        HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
+    return count;
+}
+
+} // namespace tensorrt_llm
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h
new file mode 100644
index 000000000..959d0ea08
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h
@@ -0,0 +1,44 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "cutlass/arch/mma_sm90.h"
+#include "cutlass_extensions/epilogue_helpers.h"
+
+namespace tensorrt_llm::kernels::cutlass_kernels
+{
+
+// Hopper arch
+template
+constexpr bool isValidHopperMOESpecialisation()
+{
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+    return cutlass::platform::is_same::value
+        && cutlass::platform::is_same::value;
+#else
+    return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled
+#endif
+}
+
+// Ampere arch
+template
+constexpr bool isValidAmpereMOESpecialisation()
+{
+    return true; // Default to true
+}
+
+} // namespace tensorrt_llm::kernels::cutlass_kernels
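+
+// End-to-end usage sketch for the runner added in this patch (illustrative only;
+// the buffer names and the config choice below are placeholders, not symbols from
+// the patch itself):
+//
+//   using namespace tensorrt_llm;
+//   MoeGemmRunner<half, cutlass::uint4b_t, half> runner; // queries SM version and SM count
+//   auto configs = runner.getConfigs();                  // Ampere + Hopper candidates
+//   cutlass_extensions::CutlassGemmConfig chosen_conf = configs.front(); // e.g. pick by profiling
+//   size_t workspace_bytes = runner.getMaxWorkspaceSize(num_experts);
+//   runner.moeGemmBiasAct(A, B, weight_scales, /*biases=*/nullptr,
+//       /*bias_is_broadcast=*/true, C, total_tokens_including_expert, hopper_input,
+//       total_rows, gemm_n, gemm_k, num_experts, ActivationType::Silu,
+//       /*use_fused_moe=*/false, alpha_scale_ptr_array, stream, chosen_conf);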