enable full cudagraph for deepseek
This commit is contained in:
@@ -99,29 +99,29 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||||
# so we can only use the persistent buffer if a cudagraph is actually
|
# so we can only use the persistent buffer if a cudagraph is actually
|
||||||
# being used.
|
# being used.
|
||||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
# if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
# assert self.cg_buf_tile_scheduler_metadata is not None
|
||||||
assert self.cg_buf_num_splits is not None
|
# assert self.cg_buf_num_splits is not None
|
||||||
|
|
||||||
sm_parts = tile_scheduler_metadata.size(0)
|
# sm_parts = tile_scheduler_metadata.size(0)
|
||||||
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
|
# # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
|
||||||
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
|
# assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
|
||||||
tile_scheduler_metadata_view = \
|
# tile_scheduler_metadata_view = \
|
||||||
self.cg_buf_tile_scheduler_metadata[:sm_parts]
|
# self.cg_buf_tile_scheduler_metadata[:sm_parts]
|
||||||
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
|
# tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
|
||||||
tile_scheduler_metadata = tile_scheduler_metadata_view
|
# tile_scheduler_metadata = tile_scheduler_metadata_view
|
||||||
|
|
||||||
# Num splits is per-batch, varying size (batch_size,)
|
# # Num splits is per-batch, varying size (batch_size,)
|
||||||
n = num_splits.size(0)
|
# n = num_splits.size(0)
|
||||||
# make sure static buffer is large enough
|
# # make sure static buffer is large enough
|
||||||
assert n <= self.cg_buf_num_splits.size(0)
|
# assert n <= self.cg_buf_num_splits.size(0)
|
||||||
num_splits_view = self.cg_buf_num_splits[:n]
|
# num_splits_view = self.cg_buf_num_splits[:n]
|
||||||
num_splits_view.copy_(num_splits)
|
# num_splits_view.copy_(num_splits)
|
||||||
# Num splits needs to be monotonically increasing
|
# # Num splits needs to be monotonically increasing
|
||||||
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
|
# # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
|
||||||
# it needs to be monotonically increasing by 1)
|
# # it needs to be monotonically increasing by 1)
|
||||||
self.cg_buf_num_splits[n:].fill_(num_splits[-1])
|
# self.cg_buf_num_splits[n:].fill_(num_splits[-1])
|
||||||
num_splits = num_splits_view
|
# num_splits = num_splits_view
|
||||||
|
|
||||||
return FlashMLADecodeMetadata(
|
return FlashMLADecodeMetadata(
|
||||||
block_table=block_table_tensor,
|
block_table=block_table_tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user