Merge pull request #106 from baoqian426/enable-full-cudagraph-deepseek
enable full cudagraph for deepseek
This commit is contained in:
@@ -99,29 +99,29 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||
# so we can only use the persistent buffer if a cudagraph is actually
|
||||
# being used.
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
||||
assert self.cg_buf_num_splits is not None
|
||||
# if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
# assert self.cg_buf_tile_scheduler_metadata is not None
|
||||
# assert self.cg_buf_num_splits is not None
|
||||
|
||||
sm_parts = tile_scheduler_metadata.size(0)
|
||||
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
|
||||
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
|
||||
tile_scheduler_metadata_view = \
|
||||
self.cg_buf_tile_scheduler_metadata[:sm_parts]
|
||||
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
|
||||
tile_scheduler_metadata = tile_scheduler_metadata_view
|
||||
# sm_parts = tile_scheduler_metadata.size(0)
|
||||
# # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
|
||||
# assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
|
||||
# tile_scheduler_metadata_view = \
|
||||
# self.cg_buf_tile_scheduler_metadata[:sm_parts]
|
||||
# tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
|
||||
# tile_scheduler_metadata = tile_scheduler_metadata_view
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
num_splits_view.copy_(num_splits)
|
||||
# Num splits needs to monotonically increasing
|
||||
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
|
||||
# it needs to monotonically increasing by 1)
|
||||
self.cg_buf_num_splits[n:].fill_(num_splits[-1])
|
||||
num_splits = num_splits_view
|
||||
# # Num splits is per-batch, varying size (batch_size,)
|
||||
# n = num_splits.size(0)
|
||||
# # make sure static buffer is large enough
|
||||
# assert n <= self.cg_buf_num_splits.size(0)
|
||||
# num_splits_view = self.cg_buf_num_splits[:n]
|
||||
# num_splits_view.copy_(num_splits)
|
||||
# # Num splits needs to monotonically increasing
|
||||
# # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
|
||||
# # it needs to monotonically increasing by 1)
|
||||
# self.cg_buf_num_splits[n:].fill_(num_splits[-1])
|
||||
# num_splits = num_splits_view
|
||||
|
||||
return FlashMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
|
||||
Reference in New Issue
Block a user