Add dsv3 fused a gemm to sgl-kernel (#7630)
This commit is contained in:
@@ -201,6 +201,8 @@ void bmm_fp8(
|
||||
int64_t cublas_handle,
|
||||
int64_t cuda_stream);
|
||||
|
||||
// Declaration only; the implementation is not part of this header hunk.
// Writes its result into `output` (mutable reference) and reads `mat_a` and `mat_b`
// (const references).
// NOTE(review): presumably computes a GEMM of mat_a and mat_b for the DeepSeek-V3
// fused-A projection; expected shapes, dtypes and layouts are not visible here —
// confirm against the kernel implementation.
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
|
||||
@@ -241,6 +241,23 @@ inline int getSMVersion() {
|
||||
return sm_major * 10 + sm_minor;
|
||||
}
|
||||
|
||||
// Reads the environment variable `name` and interprets it as a boolean flag.
// Returns true only when the variable is set to exactly the string "1".
inline bool getBoolEnv(char const* name) {
|
||||
// std::getenv returns a null pointer when the variable is not set.
char const* env = std::getenv(name);
|
||||
// Unset, empty, or any other value (e.g. "true", "10", "01") evaluates to false.
return env && env[0] == '1' && env[1] == '\0';
|
||||
}
|
||||
|
||||
// Reports whether PDL has been requested through the environment.
// The answer is computed on the first call and cached in a function-local static,
// so later changes to the environment variable are not observed.
// Returns false unless getSMVersion() >= 90 and `TRTLLM_ENABLE_PDL` is exactly "1".
inline bool getEnvEnablePDL() {
|
||||
// Guards the one-time initialization of `enablePDL` below.
static std::once_flag flag;
|
||||
// Cached result; remains false when the SM version check fails.
static bool enablePDL = false;
|
||||
std::call_once(flag, [&]() {
|
||||
// getSMVersion() encodes the device as major * 10 + minor, so 90 means SM 9.0.
if (getSMVersion() >= 90) {
|
||||
// PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
|
||||
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
|
||||
}
|
||||
});
|
||||
return enablePDL;
|
||||
}
|
||||
|
||||
// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
|
||||
#ifndef USE_ROCM
|
||||
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
|
||||
|
||||
Reference in New Issue
Block a user