diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index f6f243acc..875104544 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -510,6 +510,17 @@ class GroupCoordinator:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)
 
+        if (
+            self.pynccl_comm is not None
+            and hasattr(input_, "symmetric_memory")
+            and input_.symmetric_memory
+        ):
+            with self.pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream()
+            ):
+                self.pynccl_comm.all_reduce(input_)
+            return input_
+
         outplace_all_reduce_method = None
         if (
             self.qr_comm is not None
diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index 035b8bee7..df2b77e08 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -1311,7 +1315,9 @@ class RowParallelLinear(LinearBase):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+            sm.tag(output_parallel)
 
         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index d9862f674..5f219739c 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,
diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py
index b2ad1a824..66abb7541 100644
--- a/python/sglang/srt/layers/vocab_parallel_embedding.py
+++ b/python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
@@ -468,10 +472,10 @@ class VocabParallelEmbedding(torch.nn.Module):
             )
         else:
             masked_input = input_
-        # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
-
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.embedding(self, masked_input.long())
+            sm.tag(output_parallel)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index a25c59948..168ad9f29 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -25,6 +25,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
@@ -34,6 +35,9 @@ from sglang.srt.distributed import (
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -524,8 +528,12 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states *= self.routed_scaling_factor
             current_stream.wait_stream(self.alt_stream)
 
-        final_hidden_states += shared_output
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            final_hidden_states_out = torch.empty_like(final_hidden_states)
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+            sm.tag(final_hidden_states)
 
         if (
             self.tp_size > 1
             and not should_allreduce_fusion
@@ -563,8 +571,11 @@ class DeepseekV2MoE(nn.Module):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
         if (
             self.tp_size > 1
             and not should_allreduce_fusion
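
Note on how the pieces fit together: `use_symmetric_memory(...)` scopes allocations to a NCCL-registered ("symmetric") pool, `sm.tag(...)` marks the tensor produced inside the scope, and the new branch in `GroupCoordinator.all_reduce` checks that mark (`input_.symmetric_memory`) to reduce the tensor in place via `pynccl_comm`, skipping the out-of-place fallbacks. Below is a minimal sketch of the tagging protocol only; `_Scope`, `use_symmetric_memory_sketch`, and `all_reduce_sketch` are hypothetical stand-ins, and the real allocator in `pynccl_allocator` additionally registers the buffers with NCCL, which the sketch omits:

    import contextlib
    import torch

    class _Scope:
        # Hypothetical stand-in for the object yielded by use_symmetric_memory.
        def tag(self, t: torch.Tensor) -> None:
            # Mark the tensor so all_reduce() can take the symmetric-memory path.
            t.symmetric_memory = True

    @contextlib.contextmanager
    def use_symmetric_memory_sketch(tp_group):
        # The real implementation redirects CUDA allocations made inside this
        # scope to a NCCL-registered pool; omitted in this sketch.
        yield _Scope()

    def all_reduce_sketch(pynccl_comm, input_: torch.Tensor) -> torch.Tensor:
        # Mirrors the new fast path added to GroupCoordinator.all_reduce.
        if pynccl_comm is not None and getattr(input_, "symmetric_memory", False):
            pynccl_comm.all_reduce(input_)  # in-place, no staging copy
            return input_
        raise NotImplementedError  # fall through to the existing out-of-place paths

This also suggests why the deepseek_v2.py hunks replace the in-place `+=` with `torch.add(..., out=...)` into a fresh `torch.empty_like` buffer created inside the scope: the tensor handed to the subsequent all-reduce must be one that was allocated (and tagged) within the symmetric-memory region, which an in-place update of a pre-existing buffer would not provide.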