bugfix: fix merge_state_v2 cuda graph (#5419)

This commit is contained in:
DefTruth
2025-04-16 01:18:47 +08:00
committed by GitHub
parent 838fa0f218
commit 12ef7e3bc3

View File

@@ -123,7 +123,7 @@ __global__ void merge_attn_states_kernel(
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block, 0, stream>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), \
reinterpret_cast<float*>(output_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
@@ -170,6 +170,9 @@ void merge_attn_states_launcher(
dim3 block(NUM_THREADS);
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
auto stream = at::cuda::getCurrentCUDAStream();
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}