Merge pull request #106 from baoqian426/enable-full-cudagraph-deepseek

Enable full CUDA graph for DeepSeek
This commit is contained in:
baoqian426
2026-01-13 09:57:56 +08:00
committed by GitHub

View File

@@ -99,29 +99,29 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None
# if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
# assert self.cg_buf_tile_scheduler_metadata is not None
# assert self.cg_buf_num_splits is not None
sm_parts = tile_scheduler_metadata.size(0)
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
tile_scheduler_metadata_view = \
self.cg_buf_tile_scheduler_metadata[:sm_parts]
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
tile_scheduler_metadata = tile_scheduler_metadata_view
# sm_parts = tile_scheduler_metadata.size(0)
# # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
# assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
# tile_scheduler_metadata_view = \
# self.cg_buf_tile_scheduler_metadata[:sm_parts]
# tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
# tile_scheduler_metadata = tile_scheduler_metadata_view
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits)
# Num splits needs to monotonically increasing
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
# it needs to monotonically increasing by 1)
self.cg_buf_num_splits[n:].fill_(num_splits[-1])
num_splits = num_splits_view
# # Num splits is per-batch, varying size (batch_size,)
# n = num_splits.size(0)
# # make sure static buffer is large enough
# assert n <= self.cg_buf_num_splits.size(0)
# num_splits_view = self.cg_buf_num_splits[:n]
# num_splits_view.copy_(num_splits)
# # Num splits needs to monotonically increasing
# # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
# # it needs to monotonically increasing by 1)
# self.cg_buf_num_splits[n:].fill_(num_splits[-1])
# num_splits = num_splits_view
return FlashMLADecodeMetadata(
block_table=block_table_tensor,