diff --git a/sgl-kernel/csrc/attention/merge_attn_states.cu b/sgl-kernel/csrc/attention/merge_attn_states.cu
index a3b405340..c719498b8 100644
--- a/sgl-kernel/csrc/attention/merge_attn_states.cu
+++ b/sgl-kernel/csrc/attention/merge_attn_states.cu
@@ -121,18 +121,18 @@ __global__ void merge_attn_states_kernel(
     }                                                                  \
   }
 
-#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)                \
-  {                                                                    \
-    merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>(  \
-        reinterpret_cast<scalar_t*>(output.data_ptr()),                \
-        reinterpret_cast<float*>(output_lse.data_ptr()),               \
-        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),         \
-        reinterpret_cast<float*>(prefix_lse.data_ptr()),               \
-        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),         \
-        reinterpret_cast<float*>(suffix_lse.data_ptr()),               \
-        num_tokens,                                                    \
-        num_heads,                                                     \
-        head_size);                                                    \
+#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)                           \
+  {                                                                               \
+    merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block, 0, stream>>>(  \
+        reinterpret_cast<scalar_t*>(output.data_ptr()),                           \
+        reinterpret_cast<float*>(output_lse.data_ptr()),                          \
+        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),                    \
+        reinterpret_cast<float*>(prefix_lse.data_ptr()),                          \
+        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),                    \
+        reinterpret_cast<float*>(suffix_lse.data_ptr()),                          \
+        num_tokens,                                                               \
+        num_heads,                                                                \
+        head_size);                                                               \
   }
 
 /*@brief Merges the attention states from prefix and suffix
@@ -170,6 +170,9 @@ void merge_attn_states_launcher(
   dim3 block(NUM_THREADS);
   dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
 
+  const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
   LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
 }
 
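
The change follows the usual stream-correctness pattern for PyTorch CUDA extensions: pin the launch to the input tensor's device with a c10::cuda::OptionalCUDAGuard, and submit the kernel on at::cuda::getCurrentCUDAStream() instead of the legacy default stream, so the kernel is ordered with respect to the other work PyTorch has enqueued on that stream. Below is a minimal, self-contained sketch of the same pattern under assumed names; scale_kernel and scale_inplace are illustrative only and are not part of this PR.

// Sketch only (assumed names, not from this PR): the device-guard +
// current-stream launch pattern in a trivial PyTorch CUDA extension.
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

__global__ void scale_kernel(float* data, float factor, int64_t n) {
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) data[i] *= factor;
}

void scale_inplace(torch::Tensor t, float factor) {
  TORCH_CHECK(t.is_cuda() && t.scalar_type() == at::kFloat,
              "scale_inplace expects a float CUDA tensor");
  // Without the guard, the launch targets whichever device happens to be
  // current on the calling host thread, not necessarily t's device.
  const c10::cuda::OptionalCUDAGuard device_guard(t.device());
  // getCurrentCUDAStream() is the stream PyTorch enqueues ops on; launching
  // there keeps this kernel ordered after the producers of t and before any
  // later consumers, with no explicit synchronization needed.
  auto stream = at::cuda::getCurrentCUDAStream();

  const int64_t n = t.numel();
  constexpr int kThreads = 256;
  const unsigned int blocks =
      static_cast<unsigned int>((n + kThreads - 1) / kThreads);
  scale_kernel<<<blocks, kThreads, 0, stream>>>(t.data_ptr<float>(), factor, n);
}

In the diff itself, the guard/stream pair is created inside merge_attn_states_launcher, just before the macro expands, so the stream name referenced by the rewritten <<<grid, block, 0, stream>>> launch is in scope at the expansion site.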