diff --git a/python/sglang/srt/layers/attention/cutlass_mla_backend.py b/python/sglang/srt/layers/attention/cutlass_mla_backend.py index fcfd648d0..eb0cae262 100644 --- a/python/sglang/srt/layers/attention/cutlass_mla_backend.py +++ b/python/sglang/srt/layers/attention/cutlass_mla_backend.py @@ -102,7 +102,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): block_kv_indices, self.req_to_token.stride(0), max_seqlen_pad, - PAGE_SIZE, + PAGED_SIZE=PAGE_SIZE, ) workspace_size = cutlass_mla_get_workspace_size( max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1 @@ -165,7 +165,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): self.cuda_graph_kv_indices, self.req_to_token.stride(0), self.cuda_graph_kv_indices.stride(0), - PAGE_SIZE, + PAGED_SIZE=PAGE_SIZE, ) self.forward_metadata = CutlassMLADecodeMetadata( self.cuda_graph_mla_workspace, @@ -206,7 +206,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): self.cuda_graph_kv_indices, self.req_to_token.stride(0), self.cuda_graph_kv_indices.stride(0), - PAGE_SIZE, + PAGED_SIZE=PAGE_SIZE, ) else: super().init_forward_metadata_replay_cuda_graph( diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index d33201442..f255f9ce2 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -147,8 +147,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): block_kv_indices, self.req_to_token.stride(0), max_blocks, - TRITON_PAD_NUM_PAGE_PER_BLOCK, - self.page_size, + NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK, + PAGED_SIZE=self.page_size, ) return block_kv_indices @@ -204,8 +204,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): block_kv_indices, self.req_to_token.stride(0), max_seqlen_pad, - TRITON_PAD_NUM_PAGE_PER_BLOCK, - self.page_size, + NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK, + PAGED_SIZE=self.page_size, ) metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices) @@ -248,8 +248,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): metadata.block_kv_indices, self.req_to_token.stride(0), metadata.block_kv_indices.shape[1], - TRITON_PAD_NUM_PAGE_PER_BLOCK, - self.page_size, + NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK, + PAGED_SIZE=self.page_size, ) def get_cuda_graph_seq_len_fill_value(self) -> int: