Add PDL support for quant kernel and rope kernel (#9106)
This commit is contained in:
@@ -635,6 +635,8 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
|
||||
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
||||
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
|
||||
# flashinfer uses this environment variable for various kernels from MoE to quant kernels
|
||||
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
||||
|
||||
# Set prometheus env vars
|
||||
if server_args.enable_metrics:
|
||||
|
||||
@@ -550,7 +550,6 @@ class ServerArgs:
|
||||
assert (
|
||||
self.quantization == "modelopt_fp4"
|
||||
), "modelopt_fp4 quantization is required for Flashinfer MOE"
|
||||
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
||||
assert self.ep_size in [
|
||||
1,
|
||||
self.tp_size,
|
||||
|
||||
@@ -90,7 +90,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
|
||||
m.def(
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream, "
|
||||
"Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, "
|
||||
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
|
||||
|
||||
@@ -104,6 +104,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel
|
||||
uint32_t by = blockIdx.y;
|
||||
const uint32_t bdy = blockDim.y;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
|
||||
vec_t<float, vec_size> cos, sin;
|
||||
if (bx * bdy + ty < nnz) {
|
||||
const uint32_t idx = bx * bdy + ty;
|
||||
@@ -178,6 +182,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -220,6 +228,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
|
||||
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
|
||||
const uint32_t bdy = blockDim.y;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
|
||||
vec_t<float, vec_size> cos, sin;
|
||||
if (bx * bdy + ty < nnz) {
|
||||
const uint32_t idx = bx * bdy + ty;
|
||||
@@ -296,6 +308,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
#define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
|
||||
@@ -340,12 +356,59 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
|
||||
IdType* kv_cache_loc,
|
||||
bool interleave,
|
||||
bool save_kv_cache,
|
||||
bool enable_pdl,
|
||||
cudaStream_t stream = nullptr) {
|
||||
int dev_id = 0;
|
||||
int num_sms = 0;
|
||||
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
|
||||
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
|
||||
|
||||
// LAUNCH_KERNEL_RAW(kernel_name): launch `kernel_name` through cudaLaunchKernelEx.
//
// The macro captures names from the expansion site rather than taking them as
// macro parameters, so the following must be in scope where it is expanded:
//   - nblks, nthrs : assigned to config.gridDim / config.blockDim
//   - stream       : assigned to config.stream
//   - enable_pdl   : stored into the launch attribute described below
//   - q ... kv_cache_loc : forwarded to the kernel in the order listed
//
// Launch configuration built by the macro:
//   - dynamicSmemBytes is fixed at 0.
//   - Exactly one launch attribute is attached
//     (cudaLaunchAttributeProgrammaticStreamSerialization); its
//     programmaticStreamSerializationAllowed field is set from enable_pdl, so
//     the attribute is always present and only its value changes.
//
// The cudaLaunchKernelEx call is wrapped in FLASHINFER_CUDA_CALL, and the whole
// body sits in do { ... } while (0) so the expansion acts as a single statement.
#define LAUNCH_KERNEL_RAW(kernel_name) \
|
||||
do { \
|
||||
cudaLaunchConfig_t config = {}; \
|
||||
config.gridDim = nblks; \
|
||||
config.blockDim = nthrs; \
|
||||
config.dynamicSmemBytes = 0; \
|
||||
config.stream = stream; \
|
||||
cudaLaunchAttribute attrs[1] = {}; \
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; \
|
||||
config.numAttrs = 1; \
|
||||
config.attrs = attrs; \
|
||||
\
|
||||
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( \
|
||||
&config, \
|
||||
kernel_name, \
|
||||
q, \
|
||||
k, \
|
||||
v, \
|
||||
q_rope, \
|
||||
k_rope, \
|
||||
k_buffer, \
|
||||
v_buffer, \
|
||||
cos_sin_cache, \
|
||||
pos_ids, \
|
||||
nnz, \
|
||||
num_qo_heads, \
|
||||
num_kv_heads, \
|
||||
rotary_dim, \
|
||||
q_stride_n, \
|
||||
q_stride_h, \
|
||||
k_stride_n, \
|
||||
k_stride_h, \
|
||||
v_stride_n, \
|
||||
v_stride_h, \
|
||||
q_rope_stride_n, \
|
||||
q_rope_stride_h, \
|
||||
k_rope_stride_n, \
|
||||
k_rope_stride_h, \
|
||||
k_buffer_stride_n, \
|
||||
k_buffer_stride_h, \
|
||||
v_buffer_stride_n, \
|
||||
v_buffer_stride_h, \
|
||||
kv_cache_loc)); \
|
||||
} while (0)
|
||||
|
||||
DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
|
||||
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
|
||||
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
|
||||
@@ -359,35 +422,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
|
||||
uint32_t bdy = num_threads / bdx;
|
||||
// how many blocks needed to process all tokens
|
||||
uint32_t nblks_x = (nnz + bdy - 1) / bdy;
|
||||
void* args[] = {
|
||||
(void*)&q,
|
||||
(void*)&k,
|
||||
(void*)&v,
|
||||
(void*)&q_rope,
|
||||
(void*)&k_rope,
|
||||
(void*)&k_buffer,
|
||||
(void*)&v_buffer,
|
||||
(void*)&cos_sin_cache,
|
||||
(void*)&pos_ids,
|
||||
(void*)&nnz,
|
||||
(void*)&num_qo_heads,
|
||||
(void*)&num_kv_heads,
|
||||
(void*)&rotary_dim,
|
||||
(void*)&q_stride_n,
|
||||
(void*)&q_stride_h,
|
||||
(void*)&k_stride_n,
|
||||
(void*)&k_stride_h,
|
||||
(void*)&v_stride_n,
|
||||
(void*)&v_stride_h,
|
||||
(void*)&q_rope_stride_n,
|
||||
(void*)&q_rope_stride_h,
|
||||
(void*)&k_rope_stride_n,
|
||||
(void*)&k_rope_stride_h,
|
||||
(void*)&k_buffer_stride_n,
|
||||
(void*)&k_buffer_stride_h,
|
||||
(void*)&v_buffer_stride_n,
|
||||
(void*)&v_buffer_stride_h,
|
||||
(void*)&kv_cache_loc};
|
||||
|
||||
auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<
|
||||
SAVE_KV_CACHE,
|
||||
INTERLEAVE,
|
||||
@@ -405,7 +440,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
|
||||
if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
|
||||
dim3 nblks(nblks_x);
|
||||
dim3 nthrs(bdx, bdy);
|
||||
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
|
||||
LAUNCH_KERNEL_RAW(kernel_0);
|
||||
} else {
|
||||
dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
|
||||
dim3 nthrs(bdx, bdy);
|
||||
@@ -417,11 +452,12 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
|
||||
bdx,
|
||||
DType,
|
||||
IdType>;
|
||||
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
|
||||
LAUNCH_KERNEL_RAW(kernel_1);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
#undef LAUNCH_KERNEL_RAW
|
||||
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
at::Tensor cos_sin_cache,
|
||||
at::Tensor pos_ids,
|
||||
bool interleave,
|
||||
bool enable_pdl,
|
||||
int64_t cuda_stream,
|
||||
const std::optional<at::Tensor>& v,
|
||||
const std::optional<at::Tensor>& k_buffer,
|
||||
@@ -124,12 +125,14 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
kv_cache_loc_ptr,
|
||||
interleave,
|
||||
save_kv_cache,
|
||||
enable_pdl,
|
||||
stream);
|
||||
TORCH_CHECK(
|
||||
status == cudaSuccess,
|
||||
"BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
|
||||
std::string(cudaGetErrorString(status)));
|
||||
} else {
|
||||
TORCH_CHECK(!enable_pdl);
|
||||
cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
|
||||
static_cast<c_type*>(q.data_ptr()),
|
||||
static_cast<c_type*>(k.data_ptr()),
|
||||
|
||||
@@ -151,6 +151,7 @@ void apply_rope_pos_ids_cos_sin_cache(
|
||||
at::Tensor cos_sin_cache,
|
||||
at::Tensor pos_ids,
|
||||
bool interleave,
|
||||
bool enable_pdl,
|
||||
int64_t cuda_stream,
|
||||
const std::optional<at::Tensor>& v,
|
||||
const std::optional<at::Tensor>& k_buffer,
|
||||
|
||||
@@ -271,6 +271,7 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool = True,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
enable_pdl: Optional[bool] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Apply rotary embedding to keys and queries with precomputed cos/sin values.
|
||||
@@ -307,6 +308,10 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
if cos_sin_cache.dtype != torch.float32:
|
||||
raise ValueError("cos_sin_cache should be float32")
|
||||
|
||||
if enable_pdl is None:
|
||||
# the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will
|
||||
enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)
|
||||
|
||||
if (a := fused_set_kv_buffer_arg) is not None:
|
||||
assert a.k_scale is None, "k_scale is not yet supported"
|
||||
assert a.v_scale is None, "v_scale is not yet supported"
|
||||
@@ -323,6 +328,7 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
cos_sin_cache,
|
||||
positions.long(),
|
||||
(not is_neox),
|
||||
enable_pdl,
|
||||
get_cuda_stream(),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.value)
|
||||
|
||||
Reference in New Issue
Block a user