[bugifx] QWen-1M context support[2/3] using current cuda stream in the DCA's kernel for bugfix. (#8611)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: sa-buc <linzhu.ht@w32d09270.cloud.sqa.na131>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
// This file is for blocksparse attention utils cuda kernel.
|
||||
|
||||
#include <assert.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
@@ -176,7 +177,8 @@ void convert_vertical_slash_indexes_64x64(
|
||||
const dim3 dimBlock((int32_t)N_THREADS);
|
||||
const dim3 dimGrid(
|
||||
(int32_t)N_HEADS, (int32_t)BATCH_SIZE, ((int32_t)N_ROWS + (int32_t)N_THREADS - 1) / (int32_t)N_THREADS);
|
||||
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock, 0, stream>>>(
|
||||
q_seqlens,
|
||||
kv_seqlens,
|
||||
vertical_indexes,
|
||||
@@ -393,7 +395,8 @@ void convert_vertical_slash_indexes_64x64_mergehead(
|
||||
const int N_THREADS = 64;
|
||||
const dim3 dimBlock(N_THREADS);
|
||||
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
|
||||
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock, 0, stream>>>(
|
||||
q_seqlens,
|
||||
kv_seqlens,
|
||||
vertical_indexes,
|
||||
|
||||
Reference in New Issue
Block a user