Log if cuda graph is used & extend cuda graph capture to cuda-graph-max-bs (#6201)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(
|
||||
|
||||
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||
for i in range(num_loop):
|
||||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
# index into req_to_token_ptr needs to be int64
|
||||
offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
|
||||
mask = offset < kv_end - kv_start
|
||||
data = tl.load(
|
||||
req_to_token_ptr
|
||||
@@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
|
||||
num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||
|
||||
for i in range(num_pages_loop):
|
||||
# index into req_to_token_ptr needs to be int64
|
||||
paged_offset = (
|
||||
tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
|
||||
tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
|
||||
) * PAGED_SIZE
|
||||
paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
|
||||
|
||||
|
||||
Reference in New Issue
Block a user