[bugfix] Fix page size for create_flashmla_kv_indices_triton() for cutlass mla (#8685)
This commit is contained in:
@@ -102,7 +102,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
|||||||
block_kv_indices,
|
block_kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
max_seqlen_pad,
|
max_seqlen_pad,
|
||||||
PAGE_SIZE,
|
PAGED_SIZE=PAGE_SIZE,
|
||||||
)
|
)
|
||||||
workspace_size = cutlass_mla_get_workspace_size(
|
workspace_size = cutlass_mla_get_workspace_size(
|
||||||
max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
|
max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
|
||||||
@@ -165,7 +165,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.cuda_graph_kv_indices,
|
self.cuda_graph_kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
self.cuda_graph_kv_indices.stride(0),
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
PAGE_SIZE,
|
PAGED_SIZE=PAGE_SIZE,
|
||||||
)
|
)
|
||||||
self.forward_metadata = CutlassMLADecodeMetadata(
|
self.forward_metadata = CutlassMLADecodeMetadata(
|
||||||
self.cuda_graph_mla_workspace,
|
self.cuda_graph_mla_workspace,
|
||||||
@@ -206,7 +206,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.cuda_graph_kv_indices,
|
self.cuda_graph_kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
self.cuda_graph_kv_indices.stride(0),
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
PAGE_SIZE,
|
PAGED_SIZE=PAGE_SIZE,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
super().init_forward_metadata_replay_cuda_graph(
|
super().init_forward_metadata_replay_cuda_graph(
|
||||||
|
|||||||
@@ -147,8 +147,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
block_kv_indices,
|
block_kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
max_blocks,
|
max_blocks,
|
||||||
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||||
self.page_size,
|
PAGED_SIZE=self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
return block_kv_indices
|
return block_kv_indices
|
||||||
@@ -204,8 +204,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
block_kv_indices,
|
block_kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
max_seqlen_pad,
|
max_seqlen_pad,
|
||||||
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||||
self.page_size,
|
PAGED_SIZE=self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
|
metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
|
||||||
@@ -248,8 +248,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
metadata.block_kv_indices,
|
metadata.block_kv_indices,
|
||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
metadata.block_kv_indices.shape[1],
|
metadata.block_kv_indices.shape[1],
|
||||||
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||||
self.page_size,
|
PAGED_SIZE=self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self) -> int:
|
def get_cuda_graph_seq_len_fill_value(self) -> int:
|
||||||
|
|||||||
Reference in New Issue
Block a user