enable full cudagraph for deepseek

This commit is contained in:
hanhaowen
2026-01-12 15:18:12 +08:00
parent 87a57e43ca
commit ff8ebfa208

View File

@@ -99,29 +99,29 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
# TODO: we can disambiguate between decode and mixed-prefill decode here # TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually # so we can only use the persistent buffer if a cudagraph is actually
# being used. # being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): # if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None # assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None # assert self.cg_buf_num_splits is not None
sm_parts = tile_scheduler_metadata.size(0) # sm_parts = tile_scheduler_metadata.size(0)
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) # # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) # assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
tile_scheduler_metadata_view = \ # tile_scheduler_metadata_view = \
self.cg_buf_tile_scheduler_metadata[:sm_parts] # self.cg_buf_tile_scheduler_metadata[:sm_parts]
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) # tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
tile_scheduler_metadata = tile_scheduler_metadata_view # tile_scheduler_metadata = tile_scheduler_metadata_view
# Num splits is per-batch, varying size (batch_size,) # # Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0) # n = num_splits.size(0)
# make sure static buffer is large enough # # make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0) # assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n] # num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits) # num_splits_view.copy_(num_splits)
# Num splits needs to be monotonically increasing # # Num splits needs to be monotonically increasing
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise # # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
# it needs to be monotonically increasing by 1) # # it needs to be monotonically increasing by 1)
self.cg_buf_num_splits[n:].fill_(num_splits[-1]) # self.cg_buf_num_splits[n:].fill_(num_splits[-1])
num_splits = num_splits_view # num_splits = num_splits_view
return FlashMLADecodeMetadata( return FlashMLADecodeMetadata(
block_table=block_table_tensor, block_table=block_table_tensor,