Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)

Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Author: tarinkk
Date: 2025-03-27 20:09:35 -04:00
Committed by: GitHub
Parent: 98a2cfa9b2
Commit: 7f19e083c1
10 changed files with 238 additions and 47 deletions

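The change generalizes dp attention so the attention data-parallel degree no longer has to equal the tensor-parallel degree. A minimal sketch of the rank bookkeeping this implies (illustrative names, not the repository's API), assuming dp divides tp:

def dp_attention_group(tp_rank: int, tp_size: int, dp_size: int):
    """Map a tensor-parallel rank to its attention-DP replica (sketch)."""
    assert 1 <= dp_size <= tp_size and tp_size % dp_size == 0
    attn_tp_size = tp_size // dp_size  # TP degree inside one attention replica
    dp_rank = tp_rank // attn_tp_size  # which replica this rank belongs to
    attn_tp_rank = tp_rank % attn_tp_size  # rank within that replica
    return dp_rank, attn_tp_rank, attn_tp_size

With tp_size=8 and dp_size=2, ranks 0-3 would form one attention replica and ranks 4-7 the other; dp_size == tp_size recovers the previous pure-DP layout.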

@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
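
The new `enable_sp_layernorm` flag accompanies the dp < tp mode: layernorm runs sequence-parallel, with each attention-TP rank normalizing only its own slice of the tokens. A hedged sketch of why that needs no communication (RMSNorm shown, as used by the DeepSeek-style models DeepEP targets; names are illustrative):

import torch

def sp_rmsnorm_local(hidden_slice: torch.Tensor, weight: torch.Tensor,
                     eps: float = 1e-6) -> torch.Tensor:
    # Each rank holds a contiguous slice of the token dimension. RMSNorm
    # reduces only over the hidden dimension, so every token normalizes
    # locally and the sharded layernorm needs no cross-rank communication.
    x = hidden_slice.float()
    variance = x.pow(2).mean(-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps)).to(hidden_slice.dtype) * weight
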
@@ -245,8 +246,8 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.dp_size * self.num_tokens_per_bs,
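
The buffer must hold the hidden states gathered from every DP rank, hence the `max_bs * dp_size * num_tokens_per_bs` token dimension. A worked sizing example with hypothetical values (the hidden size is an assumption; the hunk truncates before the second dimension):

import torch

max_bs = 160             # largest captured batch size per DP rank (assumed)
dp_size = 2              # attention data-parallel degree
num_tokens_per_bs = 1    # decode: one token per request per step
hidden_size = 4096       # assumed model hidden size

# Room for every rank's tokens after the gather: 320 rows here.
gathered_buffer = torch.zeros(
    (max_bs * dp_size * num_tokens_per_bs, hidden_size),
    dtype=torch.bfloat16,
)
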
@@ -288,7 +289,7 @@ class CudaGraphRunner:
         self.model_runner.token_to_kv_pool.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
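
Under dp attention (and now SP layernorm) the graph to replay is keyed by the total token count across all DP ranks rather than the local batch size. A hedged sketch of the support test this hunk feeds into, with `disable_padding` mirroring `disable_cuda_graph_padding` from the first hunk:

def can_run_sketch(global_num_tokens_cpu, captured_sizes, disable_padding):
    total_global_tokens = sum(global_num_tokens_cpu)
    if disable_padding:
        # Without padding, the summed token count must match a captured size.
        return total_global_tokens in captured_sizes
    # With padding, anything up to the largest captured size can be absorbed.
    return total_global_tokens <= max(captured_sizes)
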
@@ -369,7 +370,7 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]

-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -471,7 +472,7 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             index = bisect.bisect_left(
                 self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
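
`bisect.bisect_left` selects the smallest captured batch size that can hold the summed token count, which is what the replay path pads to. A small worked example (the capture list is hypothetical):

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]       # hypothetical captured sizes
global_num_tokens_cpu = [3, 5]          # tokens on each of two DP ranks
total = sum(global_num_tokens_cpu)      # 8 tokens in flight globally

index = bisect.bisect_left(capture_bs, total)
padded_bs = capture_bs[index]           # 8: smallest captured size >= total
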
@@ -497,7 +498,7 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

         if hasattr(forward_batch.spec_info, "hidden_states"):

@@ -281,9 +281,6 @@ class ModelRunner:
         if server_args.enable_deepep_moe:
             logger.info("DeepEP is turned on.")
-            assert (
-                server_args.enable_dp_attention == True
-            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")