diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 246cfc643..9077095b1 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -635,6 +635,8 @@ def _set_envs_and_config(server_args: ServerArgs): os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" os.environ["CUDA_MODULE_LOADING"] = "AUTO" + # flashinfer uses this environment variable for various kernels from MoE to quant kernels + os.environ["TRTLLM_ENABLE_PDL"] = "1" # Set prometheus env vars if server_args.enable_metrics: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c24c63ce9..b6a98e05f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -550,7 +550,6 @@ class ServerArgs: assert ( self.quantization == "modelopt_fp4" ), "modelopt_fp4 quantization is required for Flashinfer MOE" - os.environ["TRTLLM_ENABLE_PDL"] = "1" assert self.ep_size in [ 1, self.tp_size, diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 7aab0b9d3..ac11ff2a7 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -90,7 +90,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def( "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, " - "Tensor pos_ids, bool interleave, int cuda_stream, " + "Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, " "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()"); m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); diff --git a/sgl-kernel/csrc/elementwise/pos_enc.cuh b/sgl-kernel/csrc/elementwise/pos_enc.cuh index 5388f0e74..a2e4e2ebb 100644 --- a/sgl-kernel/csrc/elementwise/pos_enc.cuh +++ b/sgl-kernel/csrc/elementwise/pos_enc.cuh @@ -104,6 +104,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel uint32_t by = blockIdx.y; const uint32_t bdy = blockDim.y; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + vec_t cos, sin; if (bx * bdy + ty < nnz) { const uint32_t idx = bx * bdy + ty; @@ -178,6 +182,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel } } } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif } template < @@ -220,6 +228,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel( uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; const uint32_t bdy = blockDim.y; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + vec_t cos, sin; if (bx * bdy + ty < nnz) { const uint32_t idx = bx * bdy + ty; @@ -296,6 +308,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel( } } } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif } #define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \ @@ -340,12 +356,59 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced( IdType* kv_cache_loc, bool interleave, bool save_kv_cache, + bool enable_pdl, cudaStream_t stream = nullptr) { int dev_id = 0; int num_sms = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); +#define LAUNCH_KERNEL_RAW(kernel_name) \ + do { \ + cudaLaunchConfig_t config = {}; \ + config.gridDim = nblks; \ + config.blockDim = nthrs; \ + config.dynamicSmemBytes = 0; \ + config.stream = stream; \ + cudaLaunchAttribute attrs[1] = {}; \ + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \ + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; \ + config.numAttrs = 1; \ + config.attrs = attrs; \ + \ + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( \ + &config, \ + kernel_name, \ + q, \ + k, \ + v, \ + q_rope, \ + k_rope, \ + k_buffer, \ + v_buffer, \ + cos_sin_cache, \ + pos_ids, \ + nnz, \ + num_qo_heads, \ + num_kv_heads, \ + rotary_dim, \ + q_stride_n, \ + q_stride_h, \ + k_stride_n, \ + k_stride_h, \ + v_stride_n, \ + v_stride_h, \ + q_rope_stride_n, \ + q_rope_stride_h, \ + k_rope_stride_n, \ + k_rope_stride_h, \ + k_buffer_stride_n, \ + k_buffer_stride_h, \ + v_buffer_stride_n, \ + v_buffer_stride_h, \ + kv_cache_loc)); \ + } while (0) + DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, { DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { @@ -359,35 +422,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced( uint32_t bdy = num_threads / bdx; // how many blocks needed to process all tokens uint32_t nblks_x = (nnz + bdy - 1) / bdy; - void* args[] = { - (void*)&q, - (void*)&k, - (void*)&v, - (void*)&q_rope, - (void*)&k_rope, - (void*)&k_buffer, - (void*)&v_buffer, - (void*)&cos_sin_cache, - (void*)&pos_ids, - (void*)&nnz, - (void*)&num_qo_heads, - (void*)&num_kv_heads, - (void*)&rotary_dim, - (void*)&q_stride_n, - (void*)&q_stride_h, - (void*)&k_stride_n, - (void*)&k_stride_h, - (void*)&v_stride_n, - (void*)&v_stride_h, - (void*)&q_rope_stride_n, - (void*)&q_rope_stride_h, - (void*)&k_rope_stride_n, - (void*)&k_rope_stride_h, - (void*)&k_buffer_stride_n, - (void*)&k_buffer_stride_h, - (void*)&v_buffer_stride_n, - (void*)&v_buffer_stride_h, - (void*)&kv_cache_loc}; + auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel< SAVE_KV_CACHE, INTERLEAVE, @@ -405,7 +440,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced( if ((nnz + bdy - 1) / bdy >= num_ctas_0) { dim3 nblks(nblks_x); dim3 nthrs(bdx, bdy); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream)); + LAUNCH_KERNEL_RAW(kernel_0); } else { dim3 nblks(nblks_x, num_qo_heads + num_kv_heads); dim3 nthrs(bdx, bdy); @@ -417,11 +452,12 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced( bdx, DType, IdType>; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream)); + LAUNCH_KERNEL_RAW(kernel_1); } }); }); }); +#undef LAUNCH_KERNEL_RAW return cudaSuccess; } diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu index 41cad7dd4..041558f61 100644 --- a/sgl-kernel/csrc/elementwise/rope.cu +++ b/sgl-kernel/csrc/elementwise/rope.cu @@ -27,6 +27,7 @@ void apply_rope_pos_ids_cos_sin_cache( at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, + bool enable_pdl, int64_t cuda_stream, const std::optional& v, const std::optional& k_buffer, @@ -124,12 +125,14 @@ void apply_rope_pos_ids_cos_sin_cache( kv_cache_loc_ptr, interleave, save_kv_cache, + enable_pdl, stream); TORCH_CHECK( status == cudaSuccess, "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " + std::string(cudaGetErrorString(status))); } else { + TORCH_CHECK(!enable_pdl); cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache( static_cast(q.data_ptr()), static_cast(k.data_ptr()), diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 007916f9d..33d883d2c 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -151,6 +151,7 @@ void apply_rope_pos_ids_cos_sin_cache( at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, + bool enable_pdl, int64_t cuda_stream, const std::optional& v, const std::optional& k_buffer, diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 559d6ef39..9abfe8384 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -271,6 +271,7 @@ def apply_rope_with_cos_sin_cache_inplace( cos_sin_cache: torch.Tensor, is_neox: bool = True, fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, + enable_pdl: Optional[bool] = None, ) -> None: r""" Apply rotary embedding to keys and queries with precomputed cos/sin values. @@ -307,6 +308,10 @@ def apply_rope_with_cos_sin_cache_inplace( if cos_sin_cache.dtype != torch.float32: raise ValueError("cos_sin_cache should be float32") + if enable_pdl is None: + # the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will + enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None) + if (a := fused_set_kv_buffer_arg) is not None: assert a.k_scale is None, "k_scale is not yet supported" assert a.v_scale is None, "v_scale is not yet supported" @@ -323,6 +328,7 @@ def apply_rope_with_cos_sin_cache_inplace( cos_sin_cache, positions.long(), (not is_neox), + enable_pdl, get_cuda_stream(), ( _view_3d(fused_set_kv_buffer_arg.value)