bugfix: launch the merge_state_v2 kernel on the current CUDA stream so it works under CUDA graph capture (#5419)
This commit is contained in:
@@ -121,18 +121,18 @@ __global__ void merge_attn_states_kernel(
|
|||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
||||||
{ \
|
{ \
|
||||||
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
|
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block, 0, stream>>>( \
|
||||||
reinterpret_cast<scalar_t*>(output.data_ptr()), \
|
reinterpret_cast<scalar_t*>(output.data_ptr()), \
|
||||||
reinterpret_cast<float*>(output_lse.data_ptr()), \
|
reinterpret_cast<float*>(output_lse.data_ptr()), \
|
||||||
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
||||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), \
|
reinterpret_cast<float*>(suffix_lse.data_ptr()), \
|
||||||
num_tokens, \
|
num_tokens, \
|
||||||
num_heads, \
|
num_heads, \
|
||||||
head_size); \
|
head_size); \
|
||||||
}
|
}
|
||||||
|
|
||||||
/*@brief Merges the attention states from prefix and suffix
|
/*@brief Merges the attention states from prefix and suffix
|
||||||
@@ -170,6 +170,9 @@ void merge_attn_states_launcher(
|
|||||||
dim3 block(NUM_THREADS);
|
dim3 block(NUM_THREADS);
|
||||||
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
||||||
|
|
||||||
|
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user