add tensorrt_llm moe_gemm as 3rdparty (#3217)
This commit is contained in:
21
sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h
vendored
Normal file
21
sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
@@ -1,187 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define CUDA_LIB_NAME "cuda"
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <windows.h>
|
||||
#define dllOpen(name) LoadLibrary("nv" name ".dll")
|
||||
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
|
||||
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
|
||||
#else // For non-Windows platforms
|
||||
#include <dlfcn.h>
|
||||
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
|
||||
#define dllClose(handle) dlclose(handle)
|
||||
#define dllGetSym(handle, name) dlsym(handle, name)
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include "cudaDriverWrapper.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include <cstdio>
|
||||
#include <cuda.h>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
|
||||
{
|
||||
static std::mutex mutex;
|
||||
static std::weak_ptr<CUDADriverWrapper> instance;
|
||||
std::shared_ptr<CUDADriverWrapper> result = instance.lock();
|
||||
if (result)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
result = instance.lock();
|
||||
if (!result)
|
||||
{
|
||||
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
|
||||
instance = result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
CUDADriverWrapper::CUDADriverWrapper()
|
||||
: handle(dllOpen(CUDA_LIB_NAME))
|
||||
{
|
||||
|
||||
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
|
||||
|
||||
auto load_sym = [](void* handle, char const* name)
|
||||
{
|
||||
void* ret = dllGetSym(handle, name);
|
||||
return ret;
|
||||
};
|
||||
|
||||
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
|
||||
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
|
||||
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
|
||||
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
|
||||
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
|
||||
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
|
||||
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
|
||||
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
|
||||
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
|
||||
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
|
||||
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
|
||||
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
|
||||
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
|
||||
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
|
||||
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
|
||||
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
|
||||
}
|
||||
|
||||
CUDADriverWrapper::~CUDADriverWrapper()
|
||||
{
|
||||
dllClose(handle);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
|
||||
{
|
||||
return (*_cuGetErrorName)(error, pStr);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
|
||||
{
|
||||
return (*_cuGetErrorMessage)(error, pStr);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
|
||||
{
|
||||
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
|
||||
{
|
||||
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
|
||||
{
|
||||
return (*_cuModuleUnload)(hmod);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
|
||||
{
|
||||
return (*_cuLinkDestroy)(state);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
|
||||
{
|
||||
return (*_cuModuleLoadData)(module, image);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkCreate(
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
|
||||
{
|
||||
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
|
||||
{
|
||||
return (*_cuModuleGetFunction)(hfunc, hmod, name);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
|
||||
{
|
||||
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues) const
|
||||
{
|
||||
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
|
||||
char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
|
||||
{
|
||||
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
|
||||
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
|
||||
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
|
||||
{
|
||||
return (*_cuLaunchCooperativeKernel)(
|
||||
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
|
||||
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
|
||||
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
|
||||
{
|
||||
return (*_cuLaunchKernel)(
|
||||
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
|
||||
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
|
||||
{
|
||||
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
|
||||
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
|
||||
{
|
||||
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
@@ -1,138 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef CUDA_DRIVER_WRAPPER_H
|
||||
#define CUDA_DRIVER_WRAPPER_H
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include <cstdio>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
class CUDADriverWrapper
|
||||
{
|
||||
public:
|
||||
static std::shared_ptr<CUDADriverWrapper> getInstance();
|
||||
|
||||
~CUDADriverWrapper();
|
||||
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
|
||||
CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
|
||||
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
|
||||
CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
|
||||
|
||||
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
|
||||
|
||||
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
|
||||
|
||||
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
|
||||
|
||||
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
|
||||
|
||||
CUresult cuModuleUnload(CUmodule hmod) const;
|
||||
|
||||
CUresult cuLinkDestroy(CUlinkState state) const;
|
||||
|
||||
CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
|
||||
|
||||
CUresult cuLinkCreate(
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
|
||||
|
||||
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
|
||||
|
||||
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
|
||||
|
||||
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
|
||||
CUjit_option* options, void** optionValues) const;
|
||||
|
||||
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues) const;
|
||||
|
||||
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
|
||||
unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
|
||||
unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
|
||||
|
||||
CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
|
||||
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
|
||||
CUstream hStream, void** kernelParams, void** extra) const;
|
||||
|
||||
CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
|
||||
void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
|
||||
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
|
||||
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
|
||||
|
||||
CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
|
||||
|
||||
private:
|
||||
void* handle;
|
||||
CUDADriverWrapper();
|
||||
|
||||
CUresult (*_cuGetErrorName)(CUresult, char const**);
|
||||
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
|
||||
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
|
||||
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
|
||||
CUresult (*_cuModuleUnload)(CUmodule);
|
||||
CUresult (*_cuLinkDestroy)(CUlinkState);
|
||||
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
|
||||
CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
|
||||
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
|
||||
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
|
||||
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
|
||||
CUresult (*_cuLinkAddData)(
|
||||
CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
|
||||
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
|
||||
unsigned int, unsigned int, unsigned int, CUstream, void**);
|
||||
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
|
||||
unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
|
||||
CUstream hStream, void** kernelParams, void** extra);
|
||||
CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
|
||||
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
|
||||
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void checkDriver(
|
||||
T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
|
||||
{
|
||||
if (result)
|
||||
{
|
||||
char const* errorName = nullptr;
|
||||
char const* errorMsg = nullptr;
|
||||
wrap.cuGetErrorName(result, &errorName);
|
||||
wrap.cuGetErrorMessage(result, &errorMsg);
|
||||
throw TllmException(
|
||||
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
/*
|
||||
* Macros compliant with TensorRT coding conventions
|
||||
*/
|
||||
#define TLLM_CU_CHECK(stat) \
|
||||
do \
|
||||
{ \
|
||||
tensorrt_llm::common::checkDriver( \
|
||||
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
|
||||
} while (0)
|
||||
|
||||
#endif // CUDA_DRIVER_WRAPPER_H
|
||||
@@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
{
|
||||
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
|
||||
typename EpilogueTag>
|
||||
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
|
||||
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
|
||||
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
|
||||
int* kernel_occupancy);
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include <cutlass_extensions/epilogue_helpers.h>
|
||||
#include <cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh>
|
||||
#include <tensorrt_llm/common/cudaUtils.h>
|
||||
|
||||
namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
{
|
||||
template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_, int Stages_,
|
||||
typename EpilogueTag>
|
||||
void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
|
||||
ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert,
|
||||
int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
|
||||
int* kernel_occupancy)
|
||||
{
|
||||
constexpr auto activation_type = fused_moe::EpilogueRouting<EpilogueTag>(true);
|
||||
using GemmType = fused_moe::Fused_Moe_Kernel_sm80<ElementType_, CutlassWeightType_, ElementType_, MaxTileM_, TileN_,
|
||||
TileK_, Stages_, activation_type>;
|
||||
|
||||
// make sure GPU has enough resources..
|
||||
if (kernel_occupancy != nullptr)
|
||||
{
|
||||
constexpr int smem_size = GemmType::kSmemSize;
|
||||
|
||||
if (smem_size > (48 << 10))
|
||||
{
|
||||
cudaFuncAttributes attr{};
|
||||
int device = 0;
|
||||
int max_smem_per_block = 0;
|
||||
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
|
||||
tensorrt_llm::common::check_cuda_error(
|
||||
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
|
||||
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global<GemmType>));
|
||||
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
|
||||
{
|
||||
// This should mean that
|
||||
// cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
// smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the
|
||||
// heuristic to ignore this configuration.
|
||||
*kernel_occupancy = 0;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, fused_moe::run_global<GemmType>, GemmType::kThreadCount, smem_size));
|
||||
*kernel_occupancy = max_active_blocks;
|
||||
return;
|
||||
}
|
||||
int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks<GemmType>());
|
||||
int const threadblock_count = multi_processor_count * occupancy;
|
||||
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
|
||||
using Arguments = typename GemmType::Arguments;
|
||||
Arguments args{{const_cast<ElementType_*>(A), const_cast<CutlassWeightType_*>(B), const_cast<ElementType_*>(biases),
|
||||
reinterpret_cast<ElementType_*>(C), total_tokens_including_expert, static_cast<int>(gemm_n),
|
||||
static_cast<int>(gemm_k), num_experts, bias_is_broadcast},
|
||||
num_experts, threadblock_count};
|
||||
auto params = GemmType::to_underlying_arguments(args);
|
||||
if (GemmType::kSmemSize >= (48 << 10))
|
||||
{
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
fused_moe::run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize);
|
||||
TLLM_CHECK_WITH_INFO(result == cudaSuccess,
|
||||
"Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel");
|
||||
}
|
||||
dim3 grid(params.threadblock_count, 1, 1);
|
||||
dim3 block(GemmType::kThreadCount);
|
||||
fused_moe::run_global<GemmType><<<grid, block, GemmType::kSmemSize, stream>>>(params);
|
||||
auto result = cudaGetLastError();
|
||||
TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result));
|
||||
}
|
||||
} // namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
// Keep in sync with the signature generated by generate_kernels.py
|
||||
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag,
|
||||
HopperGroupedGemmInput::EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool BIAS>
|
||||
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
|
||||
int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@@ -0,0 +1,348 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
|
||||
|
||||
// Hopper helper class for defining all the cutlass helper types
|
||||
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, typename TileShape,
|
||||
typename ClusterShape, bool BIAS, EpilogueFusion FUSION>
|
||||
struct HopperGroupedGemmInfo
|
||||
{
|
||||
using Arch = cutlass::arch::Sm90;
|
||||
|
||||
// TODO Update once mixed input support is added
|
||||
static_assert(cutlass::platform::is_same<T, WeightType>::value,
|
||||
"CUTLASS does not currently have specialised SM90 support for quantized operations");
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
constexpr static bool IsFP8
|
||||
= cutlass::platform::is_same<T, __nv_fp8_e4m3>::value || cutlass::platform::is_same<T, __nv_fp8_e5m2>::value;
|
||||
#else
|
||||
constexpr static bool IsFP8 = false;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|
||||
|| cutlass::platform::is_same<T, float>::value || IsFP8,
|
||||
"Specialized for bfloat16, half, float, fp8");
|
||||
#else
|
||||
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || IsFP8,
|
||||
"Specialized for half, float, fp8");
|
||||
#endif
|
||||
|
||||
static_assert(cutlass::platform::is_same<T, WeightType>::value
|
||||
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::float_e4m3_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::float_e5m2_t>::value,
|
||||
"Unexpected quantization type");
|
||||
|
||||
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
|
||||
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
|
||||
|
||||
using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter<WeightType>::type;
|
||||
// For legacy reasons we convert unsigned 8-bit to signed
|
||||
using CutlassWeightTypeMaybeUint8
|
||||
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint4, cutlass::uint4b_t>, cutlass::int4b_t,
|
||||
CutlassWeightTypeMaybeUint4>;
|
||||
using CutlassWeightType
|
||||
= std::conditional_t<std::is_same_v<CutlassWeightTypeMaybeUint8, uint8_t>, int8_t, CutlassWeightTypeMaybeUint8>;
|
||||
|
||||
using ElementA = ElementType;
|
||||
using ElementB = CutlassWeightType;
|
||||
|
||||
using ElementD = typename TllmToCutlassTypeAdapter<HopperGroupedGemmInput::OutputTypeAdaptor_t<OutputType>>::type;
|
||||
using ElementFinalOutput = typename TllmToCutlassTypeAdapter<OutputType>::type;
|
||||
|
||||
// using ElementC = std::conditional_t<BIAS, ElementType, void>;
|
||||
// using ElementCNoVoid = std::conditional_t<BIAS, ElementType, ElementD>;
|
||||
using ElementC = void;
|
||||
using ElementCNoVoid = ElementD;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using ElementBias = ElementFinalOutput;
|
||||
using ElementRouterScales = float;
|
||||
|
||||
// A matrix configuration - this is transposed and swapped with B
|
||||
using LayoutA = HopperGroupedGemmInput::LayoutA;
|
||||
constexpr static int AlignmentA
|
||||
= 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units
|
||||
// of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration - this is transposed and swapped with A
|
||||
using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand
|
||||
constexpr static int AlignmentB
|
||||
= 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units
|
||||
// of elements (up to 16 bytes)
|
||||
|
||||
// C matrix configuration
|
||||
using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand
|
||||
using StrideC = HopperGroupedGemmInput::StrideC;
|
||||
// Note we use ElementType here deliberately, so we don't break when BIAS is disabled
|
||||
constexpr static int AlignmentC
|
||||
= 128 / cutlass::sizeof_bits<ElementType>::value; // Memory access granularity/alignment of C matrix in units
|
||||
// of elements (up to 16 bytes)
|
||||
|
||||
// D matrix configuration
|
||||
using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD;
|
||||
using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD;
|
||||
constexpr static int AlignmentD
|
||||
= 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix
|
||||
// in units of elements (up to 16 bytes)
|
||||
|
||||
static_assert(cutlass::platform::is_same<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpDefault>::value,
|
||||
"Hopper Grouped GEMM specialisation doesn't support fused activation");
|
||||
|
||||
using EpilogueOp
|
||||
= cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
|
||||
|
||||
// TODO Add mode for fused activation once CUTLASS adds support
|
||||
// using EpilogueSchedule = cutlass::platform::conditional_t<
|
||||
// cutlass::platform::is_same<EpilogueOp, EpilogueOpDefault>::value,
|
||||
// cutlass::epilogue::PtrArrayNoSmemWarpSpecialized,
|
||||
// cutlass::epilogue::?????????????????? /// <<<<<< what supports activations
|
||||
// >;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;
|
||||
|
||||
// Epilogue For Default Finalize
|
||||
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< //
|
||||
Arch, cutlass::arch::OpClassTensorOp, //
|
||||
TileShape, ClusterShape, //
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, //
|
||||
ElementAccumulator, ElementAccumulator, //
|
||||
ElementC, LayoutC*, AlignmentC, //
|
||||
ElementD, LayoutD*, AlignmentD, //
|
||||
EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
// Epilogue For Fused Finalize
|
||||
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< //
|
||||
TileShape, //
|
||||
ElementCNoVoid, StrideC*, //
|
||||
ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, //
|
||||
ElementAccumulator, //
|
||||
ElementAccumulator, //
|
||||
ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, //
|
||||
ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales //
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue
|
||||
= std::conditional_t<FUSION == EpilogueFusion::FINALIZE, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
|
||||
|
||||
using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>;
|
||||
|
||||
using KernelSchedule
|
||||
= std::conditional_t<IsFP8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< //
|
||||
Arch, cutlass::arch::OpClassTensorOp, //
|
||||
CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here
|
||||
ElementType, LayoutA*, AlignmentA, //
|
||||
ElementAccumulator, //
|
||||
TileShape, ClusterShape, //
|
||||
StageCountAutoCarveout, KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<HopperGroupedGemmInput::ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
};
|
||||
|
||||
// Hopper specialised version
//
// Launches (or queries) a CUTLASS SM90 grouped GEMM for MoE. Behaviour depends
// on which optional out-parameters are set:
//   - kernel_occupancy != nullptr: only computes kernel occupancy and returns.
//   - workspace_size != nullptr: only computes the required CUTLASS workspace
//     size (using a mock problem shape) and returns.
//   - otherwise: builds the full argument set from `hopper_input` and runs the
//     grouped GEMM on `stream`.
// Requires compilation with COMPILE_HOPPER_TMA_GEMMS; otherwise throws.
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
    typename TileShape, typename ClusterShape, bool BIAS>
void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts,
    int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size)
{
#ifdef COMPILE_HOPPER_TMA_GEMMS
    using namespace cute;
    // Skip tile/cluster shape combinations disabled by FAST_BUILD filtering.
    if constexpr (!should_filter_sm90_gemm_problem_shape_v<TileShape, ClusterShape, T>)
    {
        using GemmInfo
            = HopperGroupedGemmInfo<T, WeightType, OutputType, EpilogueTag, TileShape, ClusterShape, BIAS, FUSION>;

        using ElementAccumulator = typename GemmInfo::ElementAccumulator;
        using ElementA = typename GemmInfo::ElementA;
        using ElementB = typename GemmInfo::ElementB;
        using ElementC = typename GemmInfo::ElementC;
        using ElementCNoVoid = typename GemmInfo::ElementCNoVoid;
        using ElementD = typename GemmInfo::ElementD;

        using CollectiveMainloop = typename GemmInfo::CollectiveMainloop;
        using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue;
        using GemmKernel = typename GemmInfo::GemmKernel;
        using GemmGrouped = typename GemmInfo::GemmGrouped;

        // Occupancy-query mode: report and return without touching the GPU inputs.
        if (kernel_occupancy != nullptr)
        {
            *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
            return;
        }

        cutlass::KernelHardwareInfo hw_info;
        hw_info.device_id = 0;
        hw_info.sm_count = multi_processor_count;

        GemmGrouped gemm;

        // Workspace-query mode.
        if (workspace_size != nullptr)
        {
            // Make a mock problem shape with just the minimal information actually required to get the workspace size
            // This makes some assumptions about CUTLASS's implementation which is suboptimal. We have a check later to
            // catch future cutlass updates causing silent breakages, but that is not fool proof.
            // The alternative is to wait until we have data and then dynamically allocate the workspace
            typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr};

            typename GemmGrouped::Arguments args{
                cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info};
            *workspace_size = gemm.get_workspace_size(args);
            return;
        }

        using MainloopArguments = typename CollectiveMainloop::Arguments;
        TLLM_CHECK(hopper_input.stride_a);
        TLLM_CHECK(hopper_input.stride_b);
        TLLM_CHECK(hopper_input.ptr_a);
        TLLM_CHECK(hopper_input.ptr_b);

        // Note A and B are passed to CUTLASS in swapped order (B first); this
        // matches the B^T * A^T = (A * B)^T layout trick used by GemmInfo.
        MainloopArguments const mainloop_params = {reinterpret_cast<ElementB const**>(hopper_input.ptr_b),
            hopper_input.stride_b, reinterpret_cast<ElementA const**>(hopper_input.ptr_a), hopper_input.stride_a};

        // alpha = 1, beta = 1 only when a C matrix is provided (else beta = 0).
        typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{
            ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)};
        epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array;
        using EpilogueArguments = typename CollectiveEpilogue::Arguments;
        // TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS
        // Builds the epilogue argument struct for whichever FUSION mode was compiled in.
        auto make_epi_args = [&]()
        {
            if constexpr (FUSION == EpilogueFusion::NONE)
            {
                auto epi_params = hopper_input.default_epilogue;
                return EpilogueArguments{epilogue_scalars, reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c),
                    hopper_input.stride_c, reinterpret_cast<ElementD**>(epi_params.ptr_d), epi_params.stride_d};
            }
            else if constexpr (FUSION == EpilogueFusion::FINALIZE)
            {
                // Parameters for fused finalize
                auto epi_params = hopper_input.fused_finalize_epilogue;
                return EpilogueArguments{
                    epilogue_scalars, // Parameters to underlying epilogue
                    reinterpret_cast<ElementCNoVoid const**>(hopper_input.ptr_c), hopper_input.stride_c, // C params
                    reinterpret_cast<typename GemmInfo::ElementFinalOutput*>(epi_params.ptr_final_output),
                    epi_params.stride_final_output, // D (output) params
                    reinterpret_cast<typename GemmInfo::ElementBias const*>(epi_params.ptr_bias),
                    epi_params.stride_bias,                                        // Bias params
                    epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales
                    epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales
                    epi_params.ptr_source_token_index,        // Index of the source token to sum into
                    epi_params.num_rows_in_final_output       // Number of tokens in the output buffer
                };
            }
            else
            {
                // sizeof(...) == 0 keeps the assert dependent so it only fires when instantiated.
                static_assert(
                    sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher");
            }
        };
        EpilogueArguments const epilogue_params = make_epi_args();

        typename GemmKernel::TileScheduler::Arguments scheduler_args{
            1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN};

        typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info,
            mainloop_params, epilogue_params, hw_info, scheduler_args};

        // Guard against the earlier mock-shape workspace estimate being too small
        // (see comment in the workspace-query branch above).
        size_t calculated_ws_size = gemm.get_workspace_size(args);
        TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size,
            "Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size);

        auto can_implement = gemm.can_implement(args);
        TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
            "Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));

        auto init_status = gemm.initialize(args, hopper_input.gemm_workspace);
        TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
            "Failed to initialize cutlass SM90 grouped gemm. Error: "
                + std::string(cutlassGetStatusString(init_status)));

        auto run_status = gemm.run(stream);
        TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
            "Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
        // Surfaces any asynchronous launch/execution error immediately (debug aid).
        sync_check_cuda_error();
    }
    else
    {
        TLLM_THROW("Configuration was disabled by FAST_BUILD");
    }

#else  // COMPILE_HOPPER_TMA_GEMMS
    TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
#endif // COMPILE_HOPPER_TMA_GEMMS
}
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
131
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu
vendored
Normal file
131
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu
vendored
Normal file
@@ -0,0 +1,131 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/conv/convolution.h"
|
||||
// Order matters here, packed_stride.hpp is missing cute and convolution includes
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
// Byte sizes of the 10 setup buffers required for the grouped GEMM, one entry
// per buffer, each scaled by the expert count: problem shapes; strides for
// A, B, C and the default epilogue's D; four generic pointer arrays
// (A/B/C/D); and the per-expert alpha-scale pointer array.
// Order must match the pointer assignments in configureWorkspace().
std::array<size_t, 10> HopperGroupedGemmInput::workspaceBuffers(int num_experts)
{
    auto const count = static_cast<size_t>(num_experts);

    size_t const problem_shape_bytes = count * sizeof(ProblemShape::UnderlyingProblemShape);
    size_t const stride_a_bytes = count * sizeof(StrideA);
    size_t const stride_b_bytes = count * sizeof(StrideB);
    size_t const stride_c_bytes = count * sizeof(StrideC);
    size_t const stride_d_bytes = count * sizeof(DefaultEpilogue::StrideD);

    size_t const pointer_array_bytes = count * sizeof(void*);
    size_t const scale_array_bytes = count * sizeof(float*);

    return std::array{problem_shape_bytes, stride_a_bytes, stride_b_bytes, stride_c_bytes, stride_d_bytes,
        pointer_array_bytes, pointer_array_bytes, pointer_array_bytes, pointer_array_bytes, scale_array_bytes};
}
|
||||
|
||||
// Total bytes needed for the grouped-GEMM setup workspace: the sum of all
// per-expert buffers from workspaceBuffers(), including whatever alignment
// padding calculateTotalWorkspaceSize() adds between them.
size_t HopperGroupedGemmInput::workspaceSize(int num_experts)
{
    auto sizes = workspaceBuffers(num_experts);
    return tensorrt_llm::common::calculateTotalWorkspaceSize(sizes.data(), sizes.size());
}
|
||||
|
||||
// Partitions the single pre-allocated setup buffer (start_ptr) into the 10
// sub-buffers described by workspaceBuffers() and wires this struct's pointer
// members into it. `gemm_workspace` is CUTLASS's own scratch workspace and is
// stored separately from the setup buffers carved out here.
// NOTE(review): assignment order below must stay in sync with the buffer-size
// order returned by workspaceBuffers().
void HopperGroupedGemmInput::configureWorkspace(
    int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size)
{
    auto buffers = workspaceBuffers(num_experts);
    std::array<int8_t*, 10> pointers{};
    TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers");
    for (int i = 0; i < buffers.size(); i++)
    {
        pointers[i] = start_ptr;
        // Advance past buffer i (nextWorkspacePtr accounts for its size/alignment).
        start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]);
    }

    // Buffer 0: per-expert problem shapes (device only; no host copy kept).
    shape_info.num_groups = num_experts;
    shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(pointers[0]);
    shape_info.host_problem_shapes = nullptr;
    // Buffers 1-4: per-expert strides for A, B, C and the default epilogue's D.
    stride_a = reinterpret_cast<StrideA*>(pointers[1]);
    stride_b = reinterpret_cast<StrideB*>(pointers[2]);
    stride_c = reinterpret_cast<StrideC*>(pointers[3]);
    default_epilogue.stride_d = reinterpret_cast<DefaultEpilogue::StrideD*>(pointers[4]);

    // Buffers 5-8: per-expert pointer arrays for the A/B/C/D operands.
    ptr_a = reinterpret_cast<void const**>(pointers[5]);
    ptr_b = reinterpret_cast<void const**>(pointers[6]);
    ptr_c = reinterpret_cast<void const**>(pointers[7]);
    default_epilogue.ptr_d = reinterpret_cast<void**>(pointers[8]);

    // Buffer 9: per-expert alpha-scale pointers.
    alpha_scale_ptr_array = reinterpret_cast<float const**>(pointers[9]);

    this->gemm_workspace = reinterpret_cast<uint8_t*>(gemm_workspace);
    this->gemm_workspace_size = gemm_workspace_size;
}
|
||||
|
||||
// Fills the fused-finalize epilogue parameters: output/bias/router-scale
// pointers and their cute strides. The strides are built transposed
// (via transpose_stride) to match the swapped A/B layout used by the GEMM.
// NOTE(review): the bias stride uses a 0-stride leading mode, i.e. the bias
// appears to be broadcast across rows — confirm against the epilogue builder.
void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales,
    int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
    int num_output_tokens)
{
    fused_finalize_epilogue.ptr_final_output = final_output;
    fused_finalize_epilogue.ptr_router_scales = router_scales;
    fused_finalize_epilogue.ptr_bias = bias;
    fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset;
    fused_finalize_epilogue.ptr_source_token_index = source_token_index;

    // Packed (row-major) stride for a [num_output_tokens, hidden_size] output,
    // transposed to the swapped-operand convention.
    fused_finalize_epilogue.stride_final_output
        = cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{},
            transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1)));
    fused_finalize_epilogue.stride_bias
        = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size));
    // Router scales use the compile-time stride baked into StrideRouterScales.
    fused_finalize_epilogue.stride_router_scales = {};

    fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens;
}
|
||||
|
||||
// Human-readable dump of this input struct for debugging/logging. Prints
// pointer values (not pointee contents), the active epilogue fusion mode and
// the fused-finalize parameters when that mode is selected.
std::string HopperGroupedGemmInput::toString() const
{
    std::stringstream ss;
    ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n";
    if (isValid())
    {
        ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n";
        ss << "Epilogue Fusion: " << (int) fusion;
        if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE)
        {
            ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output;
            // Fix: this previously printed stride_router_scales under the
            // "Final Output" label (copy-paste error); the output's own stride
            // is the relevant value here. Router scales are printed below.
            ss << " with Stride: " << fused_finalize_epilogue.stride_final_output;
            ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias;
            ss << " with Stride: " << fused_finalize_epilogue.stride_bias;
            ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales;
            ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales;
            ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset;
            ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index;
        }
        else
        {
            ss << ", Ptr D: " << default_epilogue.ptr_d;
        }
        ss << '\n';
        ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n";
    }
    return ss.str();
}
|
||||
} // namespace tensorrt_llm
|
||||
230
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h
vendored
Normal file
230
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h
vendored
Normal file
@@ -0,0 +1,230 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include "tensorrt_llm/common/workspace.h"
|
||||
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h"
|
||||
#include <array>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/layout/layout.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
// Swaps the first two modes of a cute shape/stride tuple:
// (t0, t1, t2, ...) -> (t1, t0, t2, ...). Used to express the transposed
// layouts required by the swapped-operand (B^T * A^T) GEMM convention.
template <class T>
constexpr auto transpose_stride(T const& t)
{
    return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
}
|
||||
|
||||
// Aggregates all device-side inputs for the SM90 (Hopper) grouped MoE GEMM:
// per-expert problem shapes, operand pointer arrays and strides, the epilogue
// configuration (default or fused-finalize), and the CUTLASS workspace.
// The pointer members are carved out of a single workspace buffer by
// configureWorkspace(); sizes come from workspaceBuffers()/workspaceSize().
struct HopperGroupedGemmInput
{
    // Stride type after swapping the first two modes (see transpose_stride).
    template <class T>
    using TransposeStride = decltype(transpose_stride<T>(T{}));
    // Maps RowMajor <-> ColumnMajor (layout of the transposed operand).
    template <class Tag>
    using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
        cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;

    static_assert(std::is_same_v<cutlass::layout::RowMajor, TransposeLayoutTag<cutlass::layout::ColumnMajor>>);
    static_assert(std::is_same_v<cutlass::layout::ColumnMajor, TransposeLayoutTag<cutlass::layout::RowMajor>>);

    // Layout for A and B is transposed and then swapped in the implementation
    // This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM
    using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>;    // Layout type for A matrix operand
    using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
    using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>;    // Layout type for C matrix operand

    using StrideA
        = std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutA*>>; // Use B because they will be swapped
    using StrideB
        = std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
    using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;

    // True for the FP8 activation types supported here (e4m3 / e5m2).
    template <class T>
    constexpr static bool IsFP8_v = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;

    // Currently this should always just be T
    // (FP8 inputs produce bf16 intermediate outputs).
    template <class T>
    using OutputTypeAdaptor_t = std::conditional_t<IsFP8_v<T>, nv_bfloat16, T>;

    // One (M, N, K) problem per expert group.
    using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int64_t, int64_t, int64_t>>;

    ProblemShape shape_info{};
    StrideA* stride_a = nullptr; // Per-expert A strides (device array)
    StrideB* stride_b = nullptr; // Per-expert B strides (device array)

    void const** ptr_a = nullptr; // Per-expert A operand pointers
    void const** ptr_b = nullptr; // Per-expert B operand pointers

    // C is currently the same in both epilogues
    StrideC* stride_c = nullptr;
    void const** ptr_c = nullptr;

    // Output configuration for the plain (non-fused) epilogue.
    struct DefaultEpilogue
    {
        using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
        using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;

        StrideD* stride_d = nullptr;
        void** ptr_d = nullptr;
    };

    // Output configuration for the fused-finalize epilogue, which scatters the
    // GEMM result into the final output buffer with router scaling and bias.
    struct FusedFinalizeEpilogue
    {
        using StrideFinalOutput = DefaultEpilogue::StrideD;
        using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
        using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;

        void* ptr_final_output = nullptr;
        StrideFinalOutput stride_final_output{};

        void const* ptr_bias = nullptr;
        StrideBias stride_bias{};

        float const* ptr_router_scales = nullptr;
        StrideRouterScales stride_router_scales{};

        // Offset of each expert's first token, and the map from GEMM row to
        // source token index in the final output.
        int64_t const* ptr_expert_first_token_offset = nullptr;
        int const* ptr_source_token_index = nullptr;

        size_t num_rows_in_final_output = 0;
    };

    DefaultEpilogue default_epilogue;
    FusedFinalizeEpilogue fused_finalize_epilogue;

    enum class EpilogueFusion
    {
        NONE,
        ACTIVATION,
        GATED_ACTIVATION,
        FINALIZE
    };
    EpilogueFusion fusion = EpilogueFusion::NONE;

    // Per-expert alpha scale pointers (e.g. FP8 dequant scales).
    float const** alpha_scale_ptr_array = nullptr;

    // CUTLASS's own scratch workspace (distinct from the setup buffers above).
    uint8_t* gemm_workspace = nullptr;
    size_t gemm_workspace_size = 0;

    // Byte sizes of the 10 setup buffers; order matches configureWorkspace().
    static std::array<size_t, 10> workspaceBuffers(int num_experts);

    // Total setup-workspace bytes (sum of workspaceBuffers with padding).
    static size_t workspaceSize(int num_experts);

    // Carves `start_ptr` into the setup buffers and stores the GEMM workspace.
    void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size);

    // True once configureWorkspace() has wired up the pointer members.
    bool isValid() const
    {
        return stride_a != nullptr && ptr_a != nullptr;
    }

    // Populates fused_finalize_epilogue (pointers + strides) for FINALIZE mode.
    void setFinalizeFusionParams(void* final_output, float const* router_scales,
        int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
        int num_output_tokens);

    // Debug dump of pointers and fusion configuration.
    std::string toString() const;
};
|
||||
|
||||
// Note update moe.py to match
// Activation applied between the two MoE GEMMs. Swiglu/Geglu are "gated"
// variants (see isGatedActivation below); InvalidType is a sentinel.
enum class ActivationType
{
    Gelu = 0,
    Relu,
    Silu,
    Swiglu, // gated: SiLU(x1) * x2
    Geglu,  // gated: GELU(x1) * x2
    Identity,
    InvalidType
};
|
||||
|
||||
constexpr bool isGatedActivation(ActivationType activation_type)
|
||||
{
|
||||
return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu;
|
||||
}
|
||||
|
||||
// Host-side dispatcher for MoE grouped GEMMs: selects between Ampere-style and
// Hopper-specialised (TMA) kernel paths, exposes config enumeration for
// autotuning, and reports workspace requirements.
template <typename T, /*The type used for activations/scales/compute*/
    typename WeightType, /* The type for the MoE weights */
    typename OutputType, /* The output type for the GEMM */
    typename ScaleBiasType = OutputType /* The type for the scales/bias */
    >
class MoeGemmRunner
{
public:
    MoeGemmRunner();

#if defined(ENABLE_FP8)
    static constexpr bool use_fp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
#else
    static constexpr bool use_fp8 = false;
#endif

    // Grouped GEMM with optional bias and fused activation epilogue.
    void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
        ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
        HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
        ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
        cutlass_extensions::CutlassGemmConfig chosen_conf);

    // Grouped GEMM without bias/activation.
    void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C,
        int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows,
        int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
        cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf);

    // Candidate kernel configurations (for the current SM or a specific one).
    std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
    static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm);
    static std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs(int sm);
    static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);

    [[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const;
    [[nodiscard]] bool supportsHopperSpecialisation() const;
    [[nodiscard]] bool isFusedGatedActivation(
        cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const;
    [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const;

    // Workspace bytes needed for `num_experts` (cached — see mutable members).
    size_t getMaxWorkspaceSize(int num_experts) const;

    [[nodiscard]] int getSM() const;

private:
    // Routes to the SM-specific kernel launcher for the given config.
    template <typename EpilogueTag>
    void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales,
        ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
        HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
        cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array,
        cudaStream_t stream, int* occupancy = nullptr);

    // Shared implementation behind moeGemm / moeGemmBiasAct.
    template <typename EpilogueTag>
    void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases,
        bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert,
        HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
        bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
        cutlass_extensions::CutlassGemmConfig chosen_conf);

private:
    int sm_{};
    int multi_processor_count_{};
    // Cache key/value for getMaxWorkspaceSize (mutable so the const getter can memoize).
    mutable int num_experts_ = 0;
    mutable size_t gemm_workspace_size_ = 0;
    size_t calcMaxWorkspaceSize(int num_experts) const;
};
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
24
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu
vendored
Normal file
24
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
// Explicit instantiation: bf16 activations, bf16 weights, bf16 output.
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
|
||||
24
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu
vendored
Normal file
24
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
// Explicit instantiation: bf16 activations, uint4 (quantized) weights, bf16 output.
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
|
||||
24
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu
vendored
Normal file
24
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
{
#ifdef ENABLE_BF16
// Explicit instantiation: bf16 activations, uint8 (quantized) weights, bf16 output.
template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>;
#endif
} // namespace tensorrt_llm
|
||||
22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
vendored
Normal file
22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
{
// Explicit instantiation: fp16 activations, fp16 weights, fp16 output.
template class MoeGemmRunner<half, half, half>;
} // namespace tensorrt_llm
|
||||
22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
vendored
Normal file
22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
template class MoeGemmRunner<half, cutlass::uint4b_t, half>;
|
||||
}
|
||||
22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
vendored
Normal file
22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
template class MoeGemmRunner<half, uint8_t, half>;
|
||||
}
|
||||
22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu
vendored
Normal file
22
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
template class MoeGemmRunner<float, float, float>;
|
||||
}
|
||||
28
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu
vendored
Normal file
28
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>;
|
||||
#ifdef ENABLE_BF16
|
||||
template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
|
||||
#endif
|
||||
// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>;
|
||||
#endif
|
||||
} // namespace tensorrt_llm
|
||||
823
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
vendored
Normal file
823
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h
vendored
Normal file
@@ -0,0 +1,823 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Ignore CUTLASS warnings about type punning
|
||||
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#endif
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
|
||||
#ifdef __GNUC__ // Restore GCC-specific diagnostics
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
|
||||
#include "moe_gemm_kernels_template_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
#include <tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels::cutlass_kernels
|
||||
{
|
||||
|
||||
// ============================= Variable batched Gemm things ===========================
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape, int Stages>
|
||||
void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr)
|
||||
{
|
||||
#if defined(ENABLE_FP8)
|
||||
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|
||||
|| cutlass::platform::is_same<T, __nv_fp8_e4m3>::value
|
||||
|| cutlass::platform::is_same<T, __nv_fp8_e5m2>::value || cutlass::platform::is_same<T, float>::value,
|
||||
"Specialized for fp8, bfloat16, half, float");
|
||||
#elif defined(ENABLE_BF16)
|
||||
static_assert(cutlass::platform::is_same<T, __nv_bfloat16>::value || cutlass::platform::is_same<T, half>::value
|
||||
|| cutlass::platform::is_same<T, float>::value,
|
||||
"Specialized for bfloat16, half, float");
|
||||
#else
|
||||
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
|
||||
"Specialized for half, float");
|
||||
#endif
|
||||
|
||||
static_assert(cutlass::platform::is_same<T, WeightType>::value
|
||||
|| cutlass::platform::is_same<WeightType, uint8_t>::value
|
||||
|| cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
|
||||
"");
|
||||
|
||||
static_assert(!cutlass::platform::is_same<arch, cutlass::arch::Sm90>::value,
|
||||
"Sm90 architecture should use specialised kernels");
|
||||
|
||||
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
|
||||
using ElementType = typename TllmToCutlassTypeAdapter<T>::type;
|
||||
using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter<GemmOutputType>::type;
|
||||
using CutlassWeightType = typename TllmToCutlassTypeAdapter<WeightType>::type;
|
||||
if (!use_fused_moe)
|
||||
{
|
||||
// We need separate config for each architecture since we will target different tensorcore instructions. For
|
||||
// float, we do not target TCs.
|
||||
using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
|
||||
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
|
||||
|
||||
using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue<CutlassGemmOutputType,
|
||||
MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;
|
||||
|
||||
typename EpilogueOp::Params epilogue_op(
|
||||
ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f));
|
||||
|
||||
#if defined(ENABLE_FP8)
|
||||
if constexpr ((std::is_same_v<T, __nv_fp8_e4m3>
|
||||
|| std::is_same_v<T, __nv_fp8_e5m2>) &&std::is_same_v<EpilogueTag,
|
||||
cutlass_extensions::EpilogueOpDefault>)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array,
|
||||
"weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 "
|
||||
"Ada");
|
||||
epilogue_op.alpha_ptr_array = alpha_scale_ptr_array;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Finally, set up the kernel.
|
||||
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped<ElementType, cutlass::layout::RowMajor,
|
||||
cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType,
|
||||
typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone,
|
||||
MixedGemmArchTraits::ElementsPerAccessB, CutlassGemmOutputType, cutlass::layout::RowMajor,
|
||||
ElementAccumulator, typename MixedGemmArchTraits::OperatorClass, arch, ThreadblockShape, WarpShape,
|
||||
typename MixedGemmArchTraits::InstructionShape, EpilogueOp,
|
||||
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages,
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::MoeFCGemm<typename GemmKernel_::Mma, typename GemmKernel_::Epilogue,
|
||||
typename GemmKernel_::ThreadblockSwizzle,
|
||||
arch, // Ensure top level arch is used for dispatch
|
||||
GemmKernel_::kGroupScheduleMode>;
|
||||
|
||||
using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
|
||||
|
||||
if (kernel_occupancy != nullptr)
|
||||
{
|
||||
*kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel>();
|
||||
return;
|
||||
}
|
||||
int occupancy = std::min(2, GemmGrouped::maximum_active_blocks());
|
||||
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel");
|
||||
int const threadblock_count = multi_processor_count * occupancy;
|
||||
|
||||
int const group_size = gemm_k;
|
||||
typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op,
|
||||
reinterpret_cast<ElementType const*>(A), reinterpret_cast<CutlassWeightType const*>(B),
|
||||
reinterpret_cast<CutlassGemmOutputType const*>(weight_scales),
|
||||
reinterpret_cast<CutlassGemmOutputType const*>(biases), bias_is_broadcast,
|
||||
reinterpret_cast<CutlassGemmOutputType*>(C), total_tokens_including_expert, gemm_n, gemm_k);
|
||||
|
||||
GemmGrouped gemm;
|
||||
|
||||
auto can_implement = gemm.can_implement(args);
|
||||
TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess,
|
||||
"MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)));
|
||||
|
||||
auto init_status = gemm.initialize(args);
|
||||
TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status)));
|
||||
|
||||
auto run_status = gemm.run(stream);
|
||||
TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess,
|
||||
"Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status)));
|
||||
}
|
||||
else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2
|
||||
&& (std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultSilu>
|
||||
|| std::is_same_v<EpilogueTag, cutlass_extensions::EpilogueOpDefaultFtGelu>) ) // use fused moe gemm
|
||||
// kernel.. (only support
|
||||
// fp16 or bf16)
|
||||
{
|
||||
sm80_generic_fused_moe_gemm_kernelLauncher<ElementType, CutlassWeightType, ThreadblockShape::kM,
|
||||
ThreadblockShape::kN, ThreadblockShape::kK, Stages, EpilogueTag>(reinterpret_cast<ElementType const*>(A),
|
||||
reinterpret_cast<CutlassWeightType const*>(B), reinterpret_cast<ElementType const*>(biases),
|
||||
bias_is_broadcast, reinterpret_cast<ElementType*>(C), total_tokens_including_expert, num_rows, gemm_n,
|
||||
gemm_k, num_experts, multi_processor_count, stream, kernel_occupancy);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernels::cutlass_kernels
|
||||
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename Arch, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape, int Stages>
|
||||
static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales, GemmOutputType const* biases,
|
||||
bool bias_is_broadcast, GemmOutputType* C, int64_t const* total_tokens_including_expert, int64_t num_rows,
|
||||
int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config,
|
||||
int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
|
||||
static_assert(!std::is_same_v<Arch, cutlass::arch::Sm90>, "Use TMA specialised functions for arch SM90");
|
||||
#if defined(ENABLE_FP8)
|
||||
constexpr bool isFp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
|
||||
#else
|
||||
constexpr bool isFp8 = false;
|
||||
#endif
|
||||
|
||||
if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80)
|
||||
&& (!isFp8 || std::is_same_v<Arch, cutlass::arch::Sm89>) )
|
||||
{
|
||||
kernels::cutlass_kernels::genericMoeGemmKernelLauncher<T, WeightType, GemmOutputType, Arch, EpilogueTag,
|
||||
ThreadblockShape, WarpShape, Stages>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW(
|
||||
"Cutlass gemm. Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape>
|
||||
void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.stages)
|
||||
{
|
||||
case 2:
|
||||
dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B, weight_scales,
|
||||
biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts,
|
||||
gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case 3:
|
||||
dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B, weight_scales,
|
||||
biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts,
|
||||
gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case 4:
|
||||
dispatch<T, WeightType, GemmOutputType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B, weight_scales,
|
||||
biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts,
|
||||
gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break;
|
||||
}
|
||||
}
|
||||
|
||||
// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
|
||||
// This overload is only enabled when T == WeightType.
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename std::enable_if<!std::is_same<T, float>::value
|
||||
#if defined(ENABLE_FP8)
|
||||
&& !std::is_same<T, __nv_fp8_e4m3>::value && !std::is_same<T, __nv_fp8_e5m2>::value
|
||||
#endif
|
||||
&& std::is_same<T, WeightType>::value>::type* = nullptr>
|
||||
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
|
||||
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
|
||||
TLLM_THROW("GEMM config should have already been set by heuristic.");
|
||||
break;
|
||||
default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break;
|
||||
}
|
||||
}
|
||||
|
||||
// Tensorop GEMM overload
|
||||
// Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve
|
||||
// compile time
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, WeightType>::value>::type* = nullptr>
|
||||
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 128, 64>,
|
||||
cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
|
||||
TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta");
|
||||
if constexpr (arch::kMinComputeCapability >= 75)
|
||||
{
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 64>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
|
||||
multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
|
||||
TLLM_THROW("GEMM config should have already been set by heuristic.");
|
||||
break;
|
||||
default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break;
|
||||
}
|
||||
}
|
||||
|
||||
// This overload will handle tensorop gemms.
|
||||
// This overload is only enabled when T == WeightType and T == __nv_fp8_e4m3 or __nv_fp8_e5m2
|
||||
#if defined(ENABLE_FP8)
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
|
||||
typename std::enable_if<(std::is_same<T, __nv_fp8_e4m3>::value || std::is_same<T, __nv_fp8_e5m2>::value)
|
||||
&& std::is_same<T, WeightType>::value>::type* = nullptr>
|
||||
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
|
||||
GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
|
||||
int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
|
||||
float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<16, 256, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 128>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<64, 64, 128>,
|
||||
cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 64, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 256, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64:
|
||||
dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C,
|
||||
total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
|
||||
use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
break;
|
||||
case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
|
||||
case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
|
||||
TLLM_THROW("GEMM config should have already been set by heuristic.");
|
||||
break;
|
||||
default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// This overload will handle simt gemms. It is disabled via SFINAE for tensorop.
// Maps the runtime-selected tile configuration to the single SIMT (FP32) CTA/warp
// shape instantiation and forwards all GEMM arguments to dispatchGemmConfig.
// Throws (TLLM_THROW) for configs that were never finalised or are not valid for SIMT.
template <typename T, typename WeightType, typename GemmOutputType, typename arch, typename EpilogueTag,
    typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales,
    GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C,
    int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr)
{
    switch (gemm_config.tile_config)
    {
    // Only one CTA/warp shape pair is compiled for the float SIMT path.
    case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
        dispatchGemmConfig<T, WeightType, GemmOutputType, arch, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 8>,
            cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count,
            use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
        break;
    case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break;
    case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default: TLLM_THROW("Unsupported config for float MoE gemm."); break;
    }
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig>
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs() const
|
||||
{
|
||||
return getConfigs(sm_);
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(
|
||||
int sm)
|
||||
{
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs = getHopperConfigs(sm);
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs = getAmpereConfigs(sm);
|
||||
std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs));
|
||||
|
||||
return candidate_configs;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig>
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getAmpereConfigs(int sm)
|
||||
{
|
||||
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
|
||||
static constexpr auto weight_only_flag
|
||||
= std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
|
||||
static constexpr auto simt_only_flag
|
||||
= std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
|
||||
static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
|
||||
int const max_split_k = 1;
|
||||
int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
|
||||
int const enable_hopper = CutlassGemmConfig::NONE;
|
||||
|
||||
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
|
||||
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);
|
||||
|
||||
if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType>())
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs
|
||||
= kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
|
||||
return ampere_configs;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig>
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getHopperConfigs(int sm)
|
||||
{
|
||||
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
|
||||
static constexpr auto weight_only_flag
|
||||
= std::is_same<T, WeightType>::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY;
|
||||
static constexpr auto simt_only_flag
|
||||
= std::is_same<T, float>::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE;
|
||||
int const max_split_k = 1;
|
||||
int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM;
|
||||
int const enable_hopper = CutlassGemmConfig::HOPPER;
|
||||
static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE;
|
||||
auto config_type_param = static_cast<CutlassGemmConfig::CandidateConfigTypeParam>(
|
||||
weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag);
|
||||
|
||||
if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> hopper_configs
|
||||
= kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param);
|
||||
return hopper_configs;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isHopperSpecialised(
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config) const
|
||||
{
|
||||
bool config_is_sm90 = gemm_config.is_sm90;
|
||||
return supportsHopperSpecialisation() && config_is_sm90;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsHopperSpecialisation() const
|
||||
{
|
||||
return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>();
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
int MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getSM() const
|
||||
{
|
||||
return this->sm_;
|
||||
}
|
||||
|
||||
// currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsFusedGatedActivation(
|
||||
bool is_gated_activation, int gemm_n, int gemm_k) const
|
||||
{
|
||||
constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true;
|
||||
return is_gated_activation && std::is_same_v<T, WeightType> && !std::is_same_v<T, float> && !use_fp8
|
||||
&& (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isFusedGatedActivation(
|
||||
cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const
|
||||
{
|
||||
return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_sm90;
|
||||
}
|
||||
|
||||
// Queries the current CUDA device once at construction: caches its compute
// capability (sm_) and SM count (multi_processor_count_) for later dispatch
// and occupancy decisions. check_cuda_error throws on any CUDA API failure.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::MoeGemmRunner()
{
    int device{-1};
    tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
    // NOTE(review): getSMVersion() presumably reads the current device set above — confirm.
    sm_ = tensorrt_llm::common::getSMVersion();
    tensorrt_llm::common::check_cuda_error(
        cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
template <typename EpilogueTag>
|
||||
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch<EpilogueTag>(T const* A,
|
||||
WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast,
|
||||
void* C_void, int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
|
||||
int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config,
|
||||
bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy)
|
||||
{
|
||||
static_assert(std::is_same_v<ScaleBiasType, OutputType>,
|
||||
"Separate Scale/Bias type is not supported. This is assumed to be the gemm output type");
|
||||
|
||||
// For now we always cast this to output type.
|
||||
// In the future this will vary based on what fusions are applied for FP8
|
||||
auto* C = reinterpret_cast<OutputType*>(C_void);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture");
|
||||
|
||||
if (sm_ >= 75 && sm_ < 80)
|
||||
{
|
||||
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm75, EpilogueTag>(A, B, weight_scales,
|
||||
biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts,
|
||||
gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy);
|
||||
}
|
||||
else if (sm_ >= 80 && sm_ < 90)
|
||||
{
|
||||
if constexpr (use_fp8)
|
||||
{
|
||||
#if defined(ENABLE_FP8)
|
||||
static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
|
||||
"FP8 GEMM Output not supported");
|
||||
#endif
|
||||
|
||||
TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
|
||||
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(A, B,
|
||||
weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
|
||||
occupancy);
|
||||
}
|
||||
else
|
||||
{
|
||||
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
|
||||
weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
|
||||
occupancy);
|
||||
}
|
||||
}
|
||||
else if (sm_ >= 90)
|
||||
{
|
||||
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>())
|
||||
{
|
||||
|
||||
// We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens
|
||||
// SM80 is faster. We check here to see which is selected
|
||||
if (gemm_config.is_sm90)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr,
|
||||
"Input biases and hopper input disagree if bias is enabled");
|
||||
TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config");
|
||||
|
||||
// Select the appropriate fusion function
|
||||
auto select_function = [&]()
|
||||
{
|
||||
switch (hopper_input.fusion)
|
||||
{
|
||||
case HopperGroupedGemmInput::EpilogueFusion::FINALIZE:
|
||||
return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
|
||||
HopperGroupedGemmInput::EpilogueFusion::FINALIZE>;
|
||||
case HopperGroupedGemmInput::EpilogueFusion::NONE:
|
||||
return &dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, EpilogueTag,
|
||||
HopperGroupedGemmInput::EpilogueFusion::NONE>;
|
||||
case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION:
|
||||
case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION:
|
||||
default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion);
|
||||
};
|
||||
};
|
||||
auto selected_func = select_function();
|
||||
selected_func(
|
||||
hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallthrough to SM80 impl below
|
||||
}
|
||||
|
||||
// Do Ampere case instead
|
||||
if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation<T, WeightType, EpilogueTag>())
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(!hopper_input.isValid(),
|
||||
"Non-specialised Hopper implementation is being rerouted to fallback implementation so input "
|
||||
"information is not required");
|
||||
TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90,
|
||||
"GEMM config is for SM90 configuration, but this configuration is not valid for Hppper");
|
||||
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(A, B,
|
||||
weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k,
|
||||
num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream,
|
||||
occupancy);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Arch unsupported for MoE GEMM");
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the workspace bytes required for the given expert count, caching the
// result of the (expensive) calculation and recomputing only when the expert
// count changes.
// NOTE(review): this const method assigns num_experts_ and gemm_workspace_size_,
// which presumably are declared mutable on the class — confirm in the header.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getMaxWorkspaceSize(int num_experts) const
{
    if (num_experts != num_experts_)
    {
        TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_);
        num_experts_ = num_experts;
        gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts);
    }
    return gemm_workspace_size_;
}
|
||||
|
||||
// Computes the maximum workspace bytes over every Hopper candidate config and
// both supported epilogue fusions (NONE and FINALIZE). Only the Hopper path
// needs a workspace; non-Hopper runners return 0. Configs that a fusion does
// not support throw and are skipped; at least one config must succeed.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspaceSize(int num_experts) const
{
    if (!supportsHopperSpecialisation())
    {
        return 0;
    }
    if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
    {
        auto configs = getHopperConfigs(sm_);
        size_t max_size = 0;
        bool has_config = false;
        for (auto conf : configs)
        {
// Try one fusion for the current config; unsupported combinations throw a
// TllmException and are skipped rather than aborting the size calculation.
// NOTE(review): macro is intentionally left defined for the rest of the TU
// in the original; consider an #undef after the loop.
#define CALC_SIZE_FUSION(FUSION)                                                                                       \
    do                                                                                                                 \
    {                                                                                                                  \
        try                                                                                                            \
        {                                                                                                              \
            size_t size = calcMaxWorkspaceSizeSM90<T, WeightType, OutputType, FUSION>(                                 \
                num_experts, conf, multi_processor_count_);                                                            \
            max_size = std::max(max_size, size);                                                                       \
            has_config = true;                                                                                         \
        }                                                                                                              \
        catch (tensorrt_llm::common::TllmException const& e)                                                           \
        {                                                                                                              \
            TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size");                          \
        }                                                                                                              \
    } while (0)

            CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::NONE);
            CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::FINALIZE);
        }
        TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size");
        return max_size;
    }
    else
    {
        TLLM_THROW("Attempting to calculate Hopper GEMM workspace size with unsupported weight combination");
        return 0;
    }
}
|
||||
|
||||
// Thin forwarder: runs the grouped GEMM with an already-chosen configuration,
// delegating all architecture selection to dispatchToArch. Occupancy output is
// not requested on this path (nullptr).
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
template <typename EpilogueTag>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::runGemm(T const* A, WeightType const* B,
    ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C,
    int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
    int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array,
    cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf)
{
    dispatchToArch<EpilogueTag>(A, B, weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert,
        hopper_input, total_rows, gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array,
        stream, nullptr);
}
|
||||
|
||||
// Runs the grouped GEMM with the bias + activation epilogue selected at runtime.
// Each activation maps to a compile-time epilogue tag; note the gated variants
// reuse the epilogue of their underlying activation (Swiglu -> Silu epilogue,
// Geglu -> FtGelu epilogue) — the gating itself is handled elsewhere.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(T const* A, WeightType const* B,
    ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C,
    int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows,
    int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe,
    float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf)
{
    switch (activation_type)
    {
    case ActivationType::Relu:
        runGemm<cutlass_extensions::EpilogueOpDefaultReLU>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::Gelu:
        runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::Silu:
        runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::Identity:
        // No activation: plain bias epilogue.
        runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::Swiglu:
        // Gated SiLU: same epilogue tag as Silu.
        runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::Geglu:
        // Gated GELU: same epilogue tag as Gelu.
        runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, bias_is_broadcast, C,
            total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe,
            alpha_scale_ptr_array, stream, chosen_conf);
        break;
    case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
    default: TLLM_THROW("Invalid activation type."); break;
    }
}
|
||||
|
||||
// Runs the grouped GEMM with no bias and no activation (default epilogue).
// Biases are passed as nullptr with bias_is_broadcast = true.
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemm(T const* A, WeightType const* B,
    ScaleBiasType const* weight_scales, void* C, int64_t const* total_tokens_including_expert,
    HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
    bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream,
    cutlass_extensions::CutlassGemmConfig chosen_conf)
{
    runGemm<cutlass_extensions::EpilogueOpDefault>(A, B, weight_scales, nullptr, true, C, total_tokens_including_expert,
        hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream,
        chosen_conf);
}
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
222
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h
vendored
Normal file
222
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h
vendored
Normal file
@@ -0,0 +1,222 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Ignore CUTLASS warnings about type punning
|
||||
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#endif // __GNUC__
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
||||
|
||||
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
|
||||
#pragma GCC diagnostic pop
|
||||
#endif // __GNUC__
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion;
|
||||
|
||||
// Final SM90 dispatch step: launches the generic Hopper grouped-GEMM kernel for
// the fully-resolved tile/cluster/fusion combination. When workspace_size is
// non-null the launcher only computes the required workspace instead of running.
// Bias support is currently compiled out (last template arg = false); the
// commented-out code below shows the intended bias selection once CUTLASS
// supports it.
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
    typename TileShape, typename ClusterShape>
void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
    cudaStream_t stream, int* occupancy, size_t* workspace_size)
{
    static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag>(),
        "Invalid hopper configuration invoked, fallback to Sm80");

    // A real launch needs valid inputs; a pure workspace-size query does not.
    TLLM_CHECK_WITH_INFO(
        workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information");

    // auto func = hopper_input.ptr_c ?
    //     kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T, WeightType,
    //     cutlass::arch::Sm90, EpilogueTag, true>
    //     :
    //     kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper<T,
    //     WeightType,
    //     cutlass::arch::Sm90, EpilogueTag, false>;
    // TODO(dastokes) Re-enable bias when CUTLASS supports it
    auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher<T, WeightType, OutputType, EpilogueTag,
        FUSION, TileShape, ClusterShape, false>;
    func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
}
|
||||
|
||||
/*
|
||||
1x1x1 cluster shape is supported for any tile shape.
|
||||
|
||||
2x1x1 cluster shape is only supported for when the M tile is at least 128.
|
||||
|
||||
1x2x1 cluster shape is only supported when the N tile is at least 128.
|
||||
|
||||
2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
|
||||
|
||||
We make the above restrictions to improve compilation speed in TRT-LLM by pruning kernels
|
||||
that may not be very useful in practice.
|
||||
*/
|
||||
template <typename CTAShape, typename ClusterShape>
|
||||
constexpr bool are_tile_shapes_supported()
|
||||
{
|
||||
using namespace cute;
|
||||
[[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{});
|
||||
[[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{});
|
||||
constexpr int cga_m = get<0>(ClusterShape{});
|
||||
constexpr int cga_n = get<1>(ClusterShape{});
|
||||
|
||||
if constexpr (cga_m == _1{} && cga_n == _1{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatches on the runtime-selected cluster (CGA) shape for a fixed tile shape,
// instantiating only the combinations allowed by are_tile_shapes_supported().
// Disallowed combinations throw instead of generating a kernel.
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION,
    typename TileShape>
void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
    size_t* workspace_size)
{
    using namespace cute;
    switch (gemm_config.cluster_shape)
    {
// Expands to one case per cluster shape; `break` only executes on the
// supported branch, the unsupported branch throws.
#define SHAPE_CASE(M, N, K)                                                                                            \
    case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K:                                               \
    {                                                                                                                  \
        using ClusterShape = Shape<_##M, _##N, _##K>;                                                                  \
        if constexpr (are_tile_shapes_supported<TileShape, ClusterShape>())                                            \
        {                                                                                                              \
            dispatchMoeGemmSelectBiasSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape>(    \
                hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);                  \
            break;                                                                                                     \
        }                                                                                                              \
        else                                                                                                           \
        {                                                                                                              \
            TLLM_THROW("Unsupported tile and cluster shape combination");                                              \
        }                                                                                                              \
    }

        SHAPE_CASE(1, 1, 1)
        SHAPE_CASE(1, 2, 1)

        SHAPE_CASE(2, 1, 1)
        SHAPE_CASE(2, 2, 1)

#undef SHAPE_CASE
    default: TLLM_THROW("Unsupported config for MoE gemm.");
    }
} // NOTE(review): closes dispatchMoeGemmSelectClusterShapeSM90 (original comment wrongly said "namespace tensorrt_llm")
|
||||
|
||||
// Dispatches on the runtime-selected SM90 CTA tile shape, then forwards to the
// cluster-shape dispatcher. Tile names carry a 'B' suffix because the K extent
// is specified in BYTES; it is converted to elements of T below.
template <typename T, typename WeightType, typename OutputType, typename EpilogueTag, EpilogueFusion FUSION>
void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts,
    cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy,
    size_t* workspace_size)
{
    using namespace cute;

    switch (gemm_config.tile_config_sm90)
    {
// K is in bytes (the ##B suffix); KtileBytes is therefore the K extent in
// ELEMENTS of T despite its name.
#define SHAPE_CASE(M, N, K)                                                                                            \
    case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B:                                        \
    {                                                                                                                  \
        constexpr int KtileBytes = K / sizeof(T);                                                                      \
        using KTileDim = Int<KtileBytes>;                                                                              \
        using TileShape = Shape<_##M, _##N, KTileDim>;                                                                 \
        dispatchMoeGemmSelectClusterShapeSM90<T, WeightType, OutputType, EpilogueTag, FUSION, TileShape>(              \
            hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size);         \
        break;                                                                                                         \
    }

        SHAPE_CASE(128, 16, 128)
        SHAPE_CASE(128, 32, 128)
        SHAPE_CASE(128, 64, 128)
        SHAPE_CASE(128, 128, 128)
        SHAPE_CASE(128, 256, 128)
        SHAPE_CASE(256, 128, 128)

#undef SHAPE_CASE
    case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break;
    case cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic:
        TLLM_THROW("GEMM config should have already been set by heuristic.");
        break;
    default: TLLM_THROW("Unsupported config for MoE gemm."); break;
    }
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, EpilogueFusion FUSION>
|
||||
size_t calcMaxWorkspaceSizeSM90(
|
||||
int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count)
|
||||
{
|
||||
size_t count;
|
||||
// Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat
|
||||
dispatchMoeGemmSelectTileShapeSM90<T, WeightType, OutputType, cutlass_extensions::EpilogueOpDefault, FUSION>(
|
||||
HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
|
||||
return count;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
44
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h
vendored
Normal file
44
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
|
||||
namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
{
|
||||
|
||||
// Hopper arch
// True when a Hopper (SM90 TMA) MoE kernel specialisation exists for this type
// combination: requires CUTLASS's modifiable-TMA SM90 support to be compiled in,
// unquantised weights (T == WeightType), and the default epilogue tag.
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidHopperMOESpecialisation()
{
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
    return cutlass::platform::is_same<T, WeightType>::value
        && cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value;
#else
    return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled
#endif
}
|
||||
|
||||
// Ampere arch (original comment wrongly said "Hopper arch")
// The Ampere path is the universal fallback: every type combination is supported.
template <typename T, typename WeightType, typename EpilogueTag = cutlass_extensions::EpilogueOpDefault>
constexpr bool isValidAmpereMOESpecialisation()
{
    return true; // Default to true
}
|
||||
|
||||
} // namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
Reference in New Issue
Block a user