Revert "Revert the changes on NCCL symmetric memory" (#10238)
@@ -510,6 +510,17 @@ class GroupCoordinator:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)
 
+        if (
+            self.pynccl_comm is not None
+            and hasattr(input_, "symmetric_memory")
+            and input_.symmetric_memory
+        ):
+            with self.pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream()
+            ):
+                self.pynccl_comm.all_reduce(input_)
+                return input_
+
         outplace_all_reduce_method = None
         if (
             self.qr_comm is not None
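Note: the new branch in GroupCoordinator.all_reduce only fires for tensors that carry a truthy symmetric_memory attribute, so ordinary allocations keep using the existing out-of-place machinery. A minimal single-process sketch of that dispatch; the _FakeNcclComm stand-in and the dispatch_all_reduce helper are illustrative, not part of the diff:

import torch


class _FakeNcclComm:
    """Illustrative stand-in for the pynccl communicator; a no-op on one process."""

    def all_reduce(self, t: torch.Tensor) -> None:
        # The real communicator reduces t in place across all ranks.
        pass


def dispatch_all_reduce(pynccl_comm, input_: torch.Tensor) -> torch.Tensor:
    # Mirrors the added branch: symmetric-memory tensors take the in-place NCCL
    # path and are returned directly; everything else falls through.
    if (
        pynccl_comm is not None
        and hasattr(input_, "symmetric_memory")
        and input_.symmetric_memory
    ):
        pynccl_comm.all_reduce(input_)
        return input_
    return input_.clone()  # placeholder for the pre-existing out-of-place paths


x = torch.ones(4)
x.symmetric_memory = True  # in the real code this flag comes from the allocator
y = dispatch_all_reduce(_FakeNcclComm(), x)
assert y is x  # the in-place path hands back the same buffer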
@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -1311,7 +1315,9 @@ class RowParallelLinear(LinearBase):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+            sm.tag(output_parallel)
 
         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
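The recurring pattern in this and the later hunks is: run the producing op inside the use_symmetric_memory scope, then sm.tag the output before it reaches tensor_model_parallel_all_reduce. A rough CPU-only sketch of that shape, with a dummy scope object standing in for use_symmetric_memory (its internals here are assumptions; the real context manager routes allocations made inside the block to NCCL window memory):

import torch


class _SymmetricMemoryScope:
    """Illustrative stand-in for use_symmetric_memory(tp_group)."""

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def tag(self, t: torch.Tensor) -> None:
        # Mark the tensor so the all-reduce dispatch above picks the pynccl path.
        t.symmetric_memory = True


def row_parallel_forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Same shape as the RowParallelLinear change: run the GEMM inside the scope,
    # then tag the output before it is handed to the TP all-reduce.
    with _SymmetricMemoryScope() as sm:
        output_parallel = x @ w
        sm.tag(output_parallel)
    return output_parallel


out = row_parallel_forward(torch.randn(2, 8), torch.randn(8, 4))
assert getattr(out, "symmetric_memory", False)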
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
@@ -468,10 +472,10 @@ class VocabParallelEmbedding(torch.nn.Module):
             )
         else:
             masked_input = input_
 
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.embedding(self, masked_input.long())
+            sm.tag(output_parallel)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
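The embedding hunk keeps the subsequent masking in place (masked_fill_), which matters here: the tagged buffer itself is modified rather than replaced by a fresh, untagged tensor before the all-reduce. A small CPU-only illustration; the tensor values and the symmetric_memory flag are fabricated for the example:

import torch

# Pretend TP-local embedding output, tagged the way sm.tag(output_parallel) would.
output_parallel = torch.arange(12.0).reshape(4, 3)
output_parallel.symmetric_memory = True
input_mask = torch.tensor([False, True, False, True])  # rows owned by other ranks

# masked_fill_ zeroes the foreign rows in place, so the same storage is reused.
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
assert getattr(output_parallel, "symmetric_memory", False)
assert output_parallel[1].abs().sum() == 0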
@@ -25,6 +25,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
@@ -34,6 +35,9 @@ from sglang.srt.distributed import (
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -524,8 +528,12 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states *= self.routed_scaling_factor
 
         current_stream.wait_stream(self.alt_stream)
-        final_hidden_states += shared_output
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            final_hidden_states_out = torch.empty_like(final_hidden_states)
+
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+            sm.tag(final_hidden_states)
         if (
             self.tp_size > 1
             and not should_allreduce_fusion
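In this hunk (and the similar one below) the simple in-place add is replaced by an empty_like allocation plus torch.add(..., out=...), presumably so the result buffer is created while the symmetric-memory allocator is active instead of reusing the pre-existing storage. A CPU-only sketch of the difference; the tensor values here are made up:

import torch

final_hidden_states = torch.randn(2, 4)  # allocated outside the scope in the real code
shared_output = torch.randn(2, 4)

# Old behaviour: in-place add keeps the original (non-symmetric) storage.
in_place = final_hidden_states.clone()
in_place += shared_output

# New behaviour: allocate a fresh buffer (inside use_symmetric_memory in the diff)
# and write the sum into it, so the tensor handed to all-reduce lives in that buffer.
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)

assert torch.allclose(in_place, final_hidden_states_out)  # same values
assert final_hidden_states_out.data_ptr() != final_hidden_states.data_ptr()  # new storage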
@@ -563,8 +571,11 @@ class DeepseekV2MoE(nn.Module):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
         if (
             self.tp_size > 1
             and not should_allreduce_fusion