From 12ef7e3bc3a322f20cede3e2bfc68205c475cdec Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Wed, 16 Apr 2025 01:18:47 +0800 Subject: [PATCH] bugfix: fix merge_state_v2 cuda graph (#5419) --- .../csrc/attention/merge_attn_states.cu | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/sgl-kernel/csrc/attention/merge_attn_states.cu b/sgl-kernel/csrc/attention/merge_attn_states.cu index a3b405340..c719498b8 100644 --- a/sgl-kernel/csrc/attention/merge_attn_states.cu +++ b/sgl-kernel/csrc/attention/merge_attn_states.cu @@ -121,18 +121,18 @@ __global__ void merge_attn_states_kernel( } \ } -#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \ - { \ - merge_attn_states_kernel<<>>( \ - reinterpret_cast(output.data_ptr()), \ - reinterpret_cast(output_lse.data_ptr()), \ - reinterpret_cast(prefix_output.data_ptr()), \ - reinterpret_cast(prefix_lse.data_ptr()), \ - reinterpret_cast(suffix_output.data_ptr()), \ - reinterpret_cast(suffix_lse.data_ptr()), \ - num_tokens, \ - num_heads, \ - head_size); \ +#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \ + { \ + merge_attn_states_kernel<<>>( \ + reinterpret_cast(output.data_ptr()), \ + reinterpret_cast(output_lse.data_ptr()), \ + reinterpret_cast(prefix_output.data_ptr()), \ + reinterpret_cast(prefix_lse.data_ptr()), \ + reinterpret_cast(suffix_output.data_ptr()), \ + reinterpret_cast(suffix_lse.data_ptr()), \ + num_tokens, \ + num_heads, \ + head_size); \ } /*@brief Merges the attention states from prefix and suffix @@ -170,6 +170,9 @@ void merge_attn_states_launcher( dim3 block(NUM_THREADS); dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS); + const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS); }