diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 875104544..f6f243acc 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -510,17 +510,6 @@ class GroupCoordinator: if self.npu_communicator is not None and not self.npu_communicator.disabled: return self.npu_communicator.all_reduce(input_) - if ( - self.pynccl_comm is not None - and hasattr(input_, "symmetric_memory") - and input_.symmetric_memory - ): - with self.pynccl_comm.change_state( - enable=True, stream=torch.cuda.current_stream() - ): - self.pynccl_comm.all_reduce(input_) - return input_ - outplace_all_reduce_method = None if ( self.qr_comm is not None diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index df2b77e08..035b8bee7 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -13,14 +13,10 @@ from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) -from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - use_symmetric_memory, -) from sglang.srt.layers.parameter import ( BasevLLMParameter, BlockQuantScaleParameter, @@ -1315,9 +1311,7 @@ class RowParallelLinear(LinearBase): # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: - output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) - sm.tag(output_parallel) + output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) if self.reduce_results and self.tp_size > 1 and not skip_all_reduce: output = tensor_model_parallel_all_reduce(output_parallel) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 5f219739c..d9862f674 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -11,12 +11,8 @@ from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size, - get_tp_group, tensor_model_parallel_all_reduce, ) -from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - use_symmetric_memory, -) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata from sglang.srt.layers.moe import ( MoeRunnerConfig, diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 66abb7541..b2ad1a824 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -11,12 +11,8 @@ from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, tensor_model_parallel_all_reduce, ) -from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - use_symmetric_memory, -) from sglang.srt.layers.amx_utils import PackWeightMethod from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.parameter import BasevLLMParameter @@ -472,10 +468,10 @@ class VocabParallelEmbedding(torch.nn.Module): ) else: masked_input = input_ + # Get the embeddings. - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: - output_parallel = self.quant_method.embedding(self, masked_input.long()) - sm.tag(output_parallel) + output_parallel = self.quant_method.embedding(self, masked_input.long()) + # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 168ad9f29..a25c59948 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -25,7 +25,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn -from tqdm import tqdm from transformers import PretrainedConfig from sglang.srt.distributed import ( @@ -35,9 +34,6 @@ from sglang.srt.distributed import ( parallel_state, tensor_model_parallel_all_reduce, ) -from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - use_symmetric_memory, -) from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo @@ -528,12 +524,8 @@ class DeepseekV2MoE(nn.Module): final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: - final_hidden_states_out = torch.empty_like(final_hidden_states) + final_hidden_states += shared_output - torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) - final_hidden_states = final_hidden_states_out - sm.tag(final_hidden_states) if ( self.tp_size > 1 and not should_allreduce_fusion @@ -571,11 +563,8 @@ class DeepseekV2MoE(nn.Module): # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor if shared_output is not None: - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: - final_hidden_states_out = torch.empty_like(final_hidden_states) - torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) - final_hidden_states = final_hidden_states_out - sm.tag(final_hidden_states) + final_hidden_states = final_hidden_states + shared_output + if ( self.tp_size > 1 and not should_allreduce_fusion