Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
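For context, a minimal sketch of the configuration this change is meant to permit: attention data parallelism with dp_size strictly smaller than tp_size while DeepEP MoE is enabled. The field names follow sglang's ServerArgs, which the diff below reads from; the model path and concrete sizes are illustrative assumptions, not part of this commit.

    # A minimal sketch, assuming sglang's ServerArgs dataclass; values are
    # illustrative only.
    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="deepseek-ai/DeepSeek-V3",  # hypothetical model choice
        tp_size=8,
        dp_size=2,  # per the commit title, 1 <= dp < tp is now supported
        enable_dp_attention=True,
        enable_deepep_moe=True,
    )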
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -245,8 +246,8 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None

-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.dp_size * self.num_tokens_per_bs,
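The gathered_buffer allocated in this hunk must hold the tokens contributed by every DP rank in the worst case, hence the max_bs * dp_size * num_tokens_per_bs leading dimension. A self-contained sketch of that sizing; the hidden size and dtype here are assumptions for illustration:

    import torch

    # Illustrative values; in CudaGraphRunner these come from the server
    # args and the model config.
    max_bs, dp_size, num_tokens_per_bs, hidden_size = 160, 2, 1, 4096

    # Worst case: every one of the dp_size ranks contributes a full
    # max_bs * num_tokens_per_bs tokens.
    gathered_buffer = torch.zeros(
        (max_bs * dp_size * num_tokens_per_bs, hidden_size),
        dtype=torch.bfloat16,
    )
    print(gathered_buffer.shape)  # torch.Size([320, 4096])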
@@ -288,7 +289,7 @@ class CudaGraphRunner:
         self.model_runner.token_to_kv_pool.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             total_global_tokens = sum(forward_batch.global_num_tokens_cpu)

             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
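Under DP attention (and now SP layernorm), whether a batch can replay a captured graph is decided from the token count summed across all DP ranks, not the local batch size. The tail of is_bs_supported is truncated in this hunk, so the predicate below is a hedged approximation of the padding rule rather than a copy of the real condition:

    def can_run_dp(global_num_tokens_cpu, capture_bs, disable_padding):
        # Token counts from every DP rank, gathered on CPU.
        total_global_tokens = sum(global_num_tokens_cpu)
        if disable_padding:
            # Without padding, the total must hit a captured size exactly.
            return total_global_tokens in capture_bs
        # With padding, any total up to the largest captured size works.
        return total_global_tokens <= capture_bs[-1]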
@@ -369,7 +370,7 @@ class CudaGraphRunner:
         encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]

-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
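During capture, the per-rank token counts are written into a preallocated GPU tensor with an in-place copy_, so replay never allocates. The list built inside torch.tensor([...]) is truncated in this hunk; the sketch below only demonstrates the preallocate-then-copy pattern with assumed values:

    import torch

    dp_size = 2
    # Preallocated once and reused across captures/replays.
    global_num_tokens_gpu = torch.zeros(dp_size, dtype=torch.int32, device="cuda")

    # Capture time: fill in place with this graph's per-rank token counts.
    num_tokens_per_rank = 8  # illustrative
    global_num_tokens_gpu.copy_(
        torch.tensor([num_tokens_per_rank] * dp_size, dtype=torch.int32, device="cuda")
    )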
@@ -471,7 +472,7 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             index = bisect.bisect_left(
                 self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
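The padding step searches the sorted list of captured batch sizes for the smallest one that can hold the global token total, then pads the batch up to it. A runnable sketch of the lookup with illustrative capture sizes:

    import bisect

    capture_bs = [1, 2, 4, 8, 16, 32]  # sorted sizes with captured graphs
    global_num_tokens_cpu = [3, 4]     # per-DP-rank token counts (dp_size=2)

    # Smallest captured size >= the global total: 7 -> 8.
    index = bisect.bisect_left(capture_bs, sum(global_num_tokens_cpu))
    padded_bs = capture_bs[index]
    print(padded_bs)  # 8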
@@ -497,7 +498,7 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

         if hasattr(forward_batch.spec_info, "hidden_states"):
@@ -281,9 +281,6 @@ class ModelRunner:

         if server_args.enable_deepep_moe:
             logger.info("DeepEP is turned on.")
-            assert (
-                server_args.enable_dp_attention == True
-            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
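The removed assert made DeepEP hard-require attention DP. A before/after illustration of the guard's effect; this helper is hypothetical, not code from the repository:

    def validate_deepep(enable_deepep_moe: bool, enable_dp_attention: bool) -> None:
        if not enable_deepep_moe:
            return
        # Before this commit, DeepEP insisted on attention DP:
        #     assert enable_dp_attention, (
        #         "Currently DeepEP is bind to Attention DP. "
        #         "Set '--enable-dp-attention --enable-deepep-moe'"
        #     )
        # After: no assert, and per the commit title DP attention itself now
        # accepts 1 <= dp_size < tp_size.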