Add dsv3 fused a gemm to sgl-kernel (#7630)

This commit is contained in:
Ke Bao
2025-06-29 17:52:24 +08:00
committed by GitHub
parent 071a1f51ae
commit 04b35190e2
9 changed files with 800 additions and 0 deletions

View File

@@ -201,6 +201,8 @@ void bmm_fp8(
int64_t cublas_handle,
int64_t cuda_stream);
// Declaration of dsv3_fused_a_gemm. `output` is taken by mutable reference,
// while `mat_a` and `mat_b` are taken by const reference.
// NOTE(review): presumably the GEMM result of mat_a and mat_b is written into
// `output`; expected shapes/dtypes are not visible here - confirm against the
// kernel implementation.
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
/*
* From csrc/moe
*/

View File

@@ -241,6 +241,23 @@ inline int getSMVersion() {
return sm_major * 10 + sm_minor;
}
// Reads the environment variable `name` and interprets it as a boolean flag.
// Returns true only when the variable is set and its value is exactly "1";
// an unset variable or any other value (including "", "0", "10", "true")
// yields false.
inline bool getBoolEnv(char const* name) {
  char const* value = std::getenv(name);
  if (value == nullptr) {
    return false;
  }
  // Exact match against the single-character string "1".
  bool const startsWithOne = (value[0] == '1');
  return startsWithOne && value[1] == '\0';
}
// Reports whether PDL is enabled for this process.
// The decision is computed once (guarded by std::call_once) and cached in a
// function-local static, so later changes to the environment are not observed.
// The result is true only when getSMVersion() is at least 90 and the
// environment variable `TRTLLM_ENABLE_PDL` is set to exactly `1`.
inline bool getEnvEnablePDL() {
  static bool pdlEnabled = false;
  static std::once_flag initOnce;
  // The statics have static storage duration, so the lambda needs no capture.
  auto const resolve = [] {
    if (getSMVersion() < 90) {
      // Below SM 90 the env variable is not consulted; the flag stays false.
      return;
    }
    // PDL is enabled by setting the env variable `TRTLLM_ENABLE_PDL` to `1`
    pdlEnabled = getBoolEnv("TRTLLM_ENABLE_PDL");
  };
  std::call_once(initOnce, resolve);
  return pdlEnabled;
}
// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
#ifndef USE_ROCM
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))