From ff8ebfa2087dd9d04d216389921d7a20f2c3f142 Mon Sep 17 00:00:00 2001 From: hanhaowen Date: Mon, 12 Jan 2026 15:18:12 +0800 Subject: [PATCH] enable full cudagraph for deepseek --- .../v1/attention/backends/mla/flashmla.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/vllm_kunlun/v1/attention/backends/mla/flashmla.py b/vllm_kunlun/v1/attention/backends/mla/flashmla.py index 46268eb..f256e02 100644 --- a/vllm_kunlun/v1/attention/backends/mla/flashmla.py +++ b/vllm_kunlun/v1/attention/backends/mla/flashmla.py @@ -99,29 +99,29 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): # TODO: we can disambiguate between decode and mixed-prefill decode here # so we can only use the persistent buffer if a cudagraph is actually # being used. - if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - assert self.cg_buf_tile_scheduler_metadata is not None - assert self.cg_buf_num_splits is not None + # if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + # assert self.cg_buf_tile_scheduler_metadata is not None + # assert self.cg_buf_num_splits is not None - sm_parts = tile_scheduler_metadata.size(0) - # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) - assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) - tile_scheduler_metadata_view = \ - self.cg_buf_tile_scheduler_metadata[:sm_parts] - tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) - tile_scheduler_metadata = tile_scheduler_metadata_view + # sm_parts = tile_scheduler_metadata.size(0) + # # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) + # assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) + # tile_scheduler_metadata_view = \ + # self.cg_buf_tile_scheduler_metadata[:sm_parts] + # tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) + # tile_scheduler_metadata = tile_scheduler_metadata_view - # Num splits is per-batch, varying size (batch_size,) - n = 
num_splits.size(0) - # make sure static buffer is large enough - assert n <= self.cg_buf_num_splits.size(0) - num_splits_view = self.cg_buf_num_splits[:n] - num_splits_view.copy_(num_splits) - # Num splits needs to monotonically increasing - # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise - # it needs to monotonically increasing by 1) - self.cg_buf_num_splits[n:].fill_(num_splits[-1]) - num_splits = num_splits_view + # # Num splits is per-batch, varying size (batch_size,) + # n = num_splits.size(0) + # # make sure static buffer is large enough + # assert n <= self.cg_buf_num_splits.size(0) + # num_splits_view = self.cg_buf_num_splits[:n] + # num_splits_view.copy_(num_splits) + # # Num splits needs to be monotonically increasing + # # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise + # # it needs to be monotonically increasing by 1) + # self.cg_buf_num_splits[n:].fill_(num_splits[-1]) + # num_splits = num_splits_view return FlashMLADecodeMetadata( block_table=block_table_tensor,