diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 47872b6b9..eed6125a9 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -189,6 +189,9 @@ class GroupCoordinator: device_group: ProcessGroup # group for device communication use_pynccl: bool # a hint of whether to use PyNccl use_custom_allreduce: bool # a hint of whether to use CustomAllreduce + use_message_queue_broadcaster: ( + bool # a hint of whether to use message queue broadcaster + ) # communicators are only created for world size > 1 pynccl_comm: Optional[Any] # PyNccl communicator ca_comm: Optional[Any] # Custom allreduce communicator @@ -241,6 +244,7 @@ class GroupCoordinator: self.use_custom_allreduce = use_custom_allreduce self.use_hpu_communicator = use_hpu_communicator self.use_xpu_communicator = use_xpu_communicator + self.use_message_queue_broadcaster = use_message_queue_broadcaster # lazy import to avoid documentation build error from sglang.srt.distributed.device_communicators.custom_all_reduce import ( @@ -269,7 +273,7 @@ class GroupCoordinator: HpuCommunicator, ) - self.hpu_communicator: Optional[HpuCommunicator] + self.hpu_communicator: Optional[HpuCommunicator] = None if use_hpu_communicator and self.world_size > 1: self.hpu_communicator = HpuCommunicator(group=self.device_group) @@ -277,7 +281,7 @@ class GroupCoordinator: XpuCommunicator, ) - self.xpu_communicator: Optional[XpuCommunicator] + self.xpu_communicator: Optional[XpuCommunicator] = None if use_xpu_communicator and self.world_size > 1: self.xpu_communicator = XpuCommunicator(group=self.device_group) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 0d593e048..c36b9706e 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -53,10 +53,8 @@ def initialize_dp_attention( ) if enable_dp_attention: - local_rank = tp_rank % (tp_size // dp_size) _DP_SIZE = dp_size else: - local_rank = tp_rank _DP_SIZE = 1 tp_group = get_tp_group() @@ -65,7 +63,7 @@ def initialize_dp_attention( list(range(head, head + _ATTN_TP_SIZE)) for head in range(0, tp_size, _ATTN_TP_SIZE) ], - local_rank, + tp_group.local_rank, torch.distributed.get_backend(tp_group.device_group), SYNC_TOKEN_IDS_ACROSS_TP, False, @@ -180,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src): memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE) -def dp_gather( +def _dp_gather( global_tokens: torch.Tensor, local_tokens: torch.Tensor, forward_batch: ForwardBatch, - layer_id: Union[str, int], + is_partial: bool, ): local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) global_tokens.fill_(0) assert local_tokens.is_contiguous() assert global_tokens.is_contiguous() - if local_tokens.shape[0] > 0 and ( - layer_id != "embedding" or get_attention_tp_rank() == 0 - ): + + if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0): assert ( global_tokens.untyped_storage().data_ptr() != local_tokens.untyped_storage().data_ptr() @@ -216,6 +213,22 @@ def dp_gather( global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens) +def dp_gather_partial( + global_tokens: torch.Tensor, + local_tokens: torch.Tensor, + forward_batch: ForwardBatch, +): + _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True) + + +def dp_gather_replicate( + global_tokens: torch.Tensor, + local_tokens: torch.Tensor, + forward_batch: ForwardBatch, +): + _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False) + + def dp_scatter( local_tokens: torch.Tensor, # output global_tokens: torch.Tensor, # input @@ -236,16 +249,3 @@ def dp_scatter( memcpy_triton( local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True ) - - -def get_do_logits_dp_scatter(forward_batch: ForwardBatch): - def do_logits_dp_scatter(logits: torch.Tensor): - local_logits = torch.empty( - (forward_batch.input_ids.shape[0], *logits.shape[1:]), - dtype=logits.dtype, - device=logits.device, - ) - dp_scatter(local_logits, logits, forward_batch) - return local_logits - - return do_logits_dp_scatter diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 0f8329ae2..981040d0d 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -28,7 +28,7 @@ from sglang.srt.distributed import ( tensor_model_parallel_all_gather, ) from sglang.srt.layers.dp_attention import ( - dp_gather, + dp_gather_replicate, dp_scatter, get_attention_dp_rank, get_attention_dp_size, @@ -428,7 +428,7 @@ class LogitsProcessor(nn.Module): logits_metadata.gathered_buffer, hidden_states.clone(), ) - dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding") + dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) if hasattr(lm_head, "weight"): logits = torch.matmul( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index c9701f166..1cbd0097a 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -33,7 +33,7 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import ( decode_attention_fwd_grouped_rope, ) from sglang.srt.layers.dp_attention import ( - dp_gather, + dp_gather_partial, dp_scatter, get_attention_dp_size, get_attention_tp_rank, @@ -939,11 +939,47 @@ class DeepseekV2DecoderLayer(nn.Module): forward_batch: ForwardBatch, residual: Optional[torch.Tensor], ) -> torch.Tensor: - if residual is None: + if hidden_states.shape[0] == 0: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + # Gather + if get_tensor_model_parallel_world_size() > 1: + # all gather and all reduce + if self.dp_size != 1: + if get_attention_tp_rank() == 0: + hidden_states += residual + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer, + hidden_states, + ) + dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + dp_scatter(residual, hidden_states, forward_batch) + hidden_states = self.post_attention_layernorm(hidden_states) + else: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + else: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + # Fully Connected + hidden_states = self.mlp(hidden_states) # Scatter if self.dp_size != 1: @@ -955,31 +991,6 @@ class DeepseekV2DecoderLayer(nn.Module): ) dp_scatter(hidden_states, global_hidden_states, forward_batch) - # Self Attention - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) - - # Gather - if get_tensor_model_parallel_world_size() > 1: - # all gather and all reduce - if self.dp_size != 1: - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer, - hidden_states, - ) - dp_gather( - hidden_states, local_hidden_states, forward_batch, self.layer_id - ) - else: - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - - # Fully Connected - hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -1025,18 +1036,6 @@ class DeepseekV2Model(nn.Module): input_embeds: torch.Tensor = None, ) -> torch.Tensor: - # Gather - if self.dp_size != 1: - input_ids, local_input_ids = ( - torch.empty( - (forward_batch.gathered_buffer.shape[0],), - dtype=input_ids.dtype, - device=input_ids.device, - ), - input_ids, - ) - dp_gather(input_ids, local_input_ids, forward_batch, "embedding") - if input_embeds is None: hidden_states = self.embed_tokens(input_ids) else: @@ -1087,15 +1086,6 @@ class DeepseekV2ForCausalLM(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) - if self.dp_size != 1: - # important: forward batch.gathered_buffer is used both after scatter and after gather. - # be careful about this! - hidden_states, global_hidden_states = ( - forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], - hidden_states, - ) - dp_scatter(hidden_states, global_hidden_states, forward_batch) - return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch ) diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py index d24507ae2..f7811911f 100644 --- a/test/srt/test_dp_attention.py +++ b/test/srt/test_dp_attention.py @@ -11,7 +11,7 @@ from sglang.test.test_utils import ( ) -class TestDPAttention(unittest.TestCase): +class TestDPAttentionDP2TP2(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -59,7 +59,3 @@ class TestDPAttention(unittest.TestCase): metrics = run_eval(args) print(f"{metrics=}") self.assertGreater(metrics["score"], 0.8) - - -if __name__ == "__main__": - unittest.main()