Revert the changes on NCCL symmetric memory (#10210)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
@@ -25,7 +25,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
 from tqdm import tqdm
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
@@ -35,9 +34,6 @@ from sglang.srt.distributed import (
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.distributed.device_communicators.pynccl_allocator import (
-    use_symmetric_memory,
-)
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -528,12 +524,8 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states *= self.routed_scaling_factor

         current_stream.wait_stream(self.alt_stream)
-        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
-            final_hidden_states_out = torch.empty_like(final_hidden_states)
-            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
-            final_hidden_states = final_hidden_states_out
-            sm.tag(final_hidden_states)
+        final_hidden_states += shared_output

         if (
             self.tp_size > 1
             and not should_allreduce_fusion
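For context on what this hunk reverts: the removed block allocated the combined output inside a symmetric-memory context and tagged it, so that the subsequent tensor-parallel all-reduce could use NCCL's symmetric (window-registered) fast path. Below is a minimal, runnable sketch of that removed pattern. Only the interface is taken from this diff; the use_symmetric_memory stub and _Handle class here are placeholders, not the actual sglang.srt.distributed.device_communicators.pynccl_allocator implementation.

# Hedged sketch of the removed pattern, runnable on CPU with plain torch.
from contextlib import contextmanager

import torch


class _Handle:
    def tag(self, tensor):
        # In the real code this marks `tensor` as living in NCCL symmetric
        # memory so the later all-reduce can take the fast path.
        # Here: a no-op placeholder.
        pass


@contextmanager
def use_symmetric_memory(tp_group=None):
    # In the real code this temporarily routes allocations made inside the
    # block into symmetric memory for the given TP group.
    # Here: yields a placeholder handle.
    yield _Handle()


final_hidden_states = torch.randn(4, 8)
shared_output = torch.randn(4, 8)

# The removed code allocated the sum inside the context so the all-reduce
# input would land in symmetric memory, then tagged it:
with use_symmetric_memory(None) as sm:
    final_hidden_states_out = torch.empty_like(final_hidden_states)
    torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
    final_hidden_states = final_hidden_states_out
    sm.tag(final_hidden_states)

print(final_hidden_states.sum())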
@@ -571,11 +563,8 @@ class DeepseekV2MoE(nn.Module):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
-            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
-                final_hidden_states_out = torch.empty_like(final_hidden_states)
-                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
-                final_hidden_states = final_hidden_states_out
-                sm.tag(final_hidden_states)
+            final_hidden_states = final_hidden_states + shared_output
+
         if (
             self.tp_size > 1
             and not should_allreduce_fusion
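After the revert, both hunks reduce to an ordinary tensor add followed by the existing all-reduce guard. The sketch below traces that post-revert control flow in plain torch; the helper name combine_outputs is illustrative, and the all-reduce is stubbed out since tensor_model_parallel_all_reduce needs an initialized TP group.

# Hedged sketch of the post-revert flow; names are stand-ins for the
# sglang code shown in this diff.
from typing import Optional

import torch


def combine_outputs(
    final_hidden_states: torch.Tensor,
    shared_output: Optional[torch.Tensor],
    routed_scaling_factor: float,
    tp_size: int,
    should_allreduce_fusion: bool,
) -> torch.Tensor:
    final_hidden_states = final_hidden_states * routed_scaling_factor
    if shared_output is not None:
        # Plain add on ordinary device memory; no symmetric-memory
        # allocation and no sm.tag() after the revert.
        final_hidden_states = final_hidden_states + shared_output
    if tp_size > 1 and not should_allreduce_fusion:
        # Stand-in for tensor_model_parallel_all_reduce(final_hidden_states).
        pass
    return final_hidden_states


out = combine_outputs(
    torch.randn(4, 8),
    torch.randn(4, 8),
    routed_scaling_factor=1.0,
    tp_size=1,
    should_allreduce_fusion=False,
)
print(out.shape)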