Revert "Revert the changes on NCCL symmetric memory" (#10238)

This commit is contained in:
Lianmin Zheng
2025-09-09 12:11:49 -07:00
committed by GitHub
parent d352c29aa0
commit 4582931ac3
5 changed files with 43 additions and 7 deletions

View File

@@ -25,6 +25,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
from sglang.srt.distributed import (
@@ -34,6 +35,9 @@ from sglang.srt.distributed import (
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -524,8 +528,12 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states *= self.routed_scaling_factor
current_stream.wait_stream(self.alt_stream)
final_hidden_states += shared_output
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if (
self.tp_size > 1
and not should_allreduce_fusion
@@ -563,8 +571,11 @@ class DeepseekV2MoE(nn.Module):
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if (
self.tp_size > 1
and not should_allreduce_fusion