[Feature] Support DeepEP normal & Redundant Experts on NPU (#9881)
This commit is contained in:
@@ -55,7 +55,7 @@ class EPLBManager:
|
||||
enable_timing = self._rebalance_layers_per_chunk is None
|
||||
|
||||
if enable_timing:
|
||||
torch.cuda.synchronize()
|
||||
torch.get_device_module().synchronize()
|
||||
time_start = time.time()
|
||||
|
||||
dump_record_output = get_global_expert_distribution_recorder().dump_record(
|
||||
@@ -85,7 +85,7 @@ class EPLBManager:
|
||||
|
||||
msg = f"[EPLBManager] rebalance end"
|
||||
if enable_timing:
|
||||
torch.cuda.synchronize()
|
||||
torch.get_device_module().synchronize()
|
||||
time_end = time.time()
|
||||
msg += f" time={time_end - time_start:.3f}s"
|
||||
logger.info(msg)
|
||||
|
||||
@@ -30,7 +30,9 @@ import torch.distributed
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import Withable, get_bool_env_var
|
||||
from sglang.srt.utils import Withable, get_bool_env_var, is_npu
|
||||
|
||||
_is_npu = is_npu()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
|
||||
@@ -216,7 +218,9 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
|
||||
def _on_hook(self, hook_name: str, **kwargs):
|
||||
if self._disable_all:
|
||||
return
|
||||
if not (self._recording or torch.cuda.is_current_stream_capturing()):
|
||||
if not (
|
||||
self._recording or torch.get_device_module().is_current_stream_capturing()
|
||||
):
|
||||
return
|
||||
gatherer = self._single_pass_gatherers[
|
||||
self._accumulator.get_single_pass_gatherer_key(
|
||||
@@ -451,6 +455,10 @@ def _list_sum(a: List, b: List) -> List:
|
||||
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
||||
def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if not _is_npu:
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "npu"
|
||||
self._enable_global_physical_experts = enable_global_physical_experts
|
||||
self._data = torch.zeros(
|
||||
(
|
||||
@@ -462,7 +470,7 @@ class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
|
||||
),
|
||||
),
|
||||
dtype=torch.int,
|
||||
device="cuda",
|
||||
device=device,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
@@ -784,7 +792,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
|
||||
|
||||
if self._first_dump:
|
||||
self._first_dump = False
|
||||
torch.cuda.empty_cache()
|
||||
torch.get_device_module().empty_cache()
|
||||
|
||||
torch.distributed.all_reduce(
|
||||
logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
|
||||
|
||||
@@ -47,7 +47,7 @@ class ExpertLocationUpdater:
|
||||
):
|
||||
if self._first_execution:
|
||||
self._first_execution = False
|
||||
torch.cuda.empty_cache()
|
||||
torch.get_device_module().empty_cache()
|
||||
|
||||
old_expert_location_metadata = get_global_expert_location_metadata()
|
||||
assert old_expert_location_metadata is not None
|
||||
|
||||
@@ -10,6 +10,7 @@ from torch.nn.functional import scaled_dot_product_attention
|
||||
from sglang.srt.configs.model_config import AttentionArch
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.radix_attention import AttentionType
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
@@ -33,6 +34,7 @@ class ForwardMetadata:
|
||||
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||
seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||
seq_lens_cpu_list: Optional[List[int]] = None
|
||||
seq_lens_list_cumsum: Optional[List[int]] = None
|
||||
|
||||
|
||||
class AscendAttnBackend(AttentionBackend):
|
||||
@@ -83,6 +85,7 @@ class AscendAttnBackend(AttentionBackend):
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Init the metadata for a forward pass."""
|
||||
tp_size = get_attention_tp_size()
|
||||
self.forward_metadata = ForwardMetadata()
|
||||
|
||||
self.forward_metadata.block_tables = (
|
||||
@@ -96,9 +99,13 @@ class AscendAttnBackend(AttentionBackend):
|
||||
forward_batch.extend_seq_lens.cpu().int()
|
||||
)
|
||||
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
||||
self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
|
||||
forward_batch.extend_seq_lens_cpu
|
||||
)
|
||||
|
||||
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
|
||||
if forward_batch.is_extend_in_batch:
|
||||
seq_lens_list_cumsum[-1] = (
|
||||
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
|
||||
) * tp_size
|
||||
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
|
||||
|
||||
self.graph_mode = False
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
AscendDeepEPLLOutput,
|
||||
DeepEPLLOutput,
|
||||
DeepEPNormalOutput,
|
||||
DispatchOutput,
|
||||
@@ -454,7 +453,7 @@ class DeepEPMoE(EPMoE):
|
||||
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
||||
return self.forward_aiter(dispatch_output)
|
||||
if _is_npu:
|
||||
assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
|
||||
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||
return self.forward_npu(dispatch_output)
|
||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
@@ -718,63 +717,124 @@ class DeepEPMoE(EPMoE):
|
||||
|
||||
def forward_npu(
|
||||
self,
|
||||
dispatch_output: DeepEPLLOutput,
|
||||
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
|
||||
):
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(dispatch_output, AscendDeepEPLLOutput)
|
||||
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
|
||||
assert self.quant_method is not None
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
|
||||
# NOTE: Ascend's Dispatch & Combine does not support FP16
|
||||
output_dtype = torch.bfloat16
|
||||
|
||||
pertoken_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
group_list_type = 1
|
||||
seg_indptr = seg_indptr.to(torch.int64)
|
||||
|
||||
import torch_npu
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=seg_indptr,
|
||||
output_dtype=torch.int32,
|
||||
)[0]
|
||||
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=seg_indptr,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
# NOTE: Ascend's Dispatch & Combine does not support FP16
|
||||
output_dtype = torch.bfloat16
|
||||
group_list_type = 1
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight],
|
||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=seg_indptr,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
def _forward_normal(dispatch_output: DeepEPNormalOutput):
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(dispatch_output, DeepEPNormalOutput)
|
||||
hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
|
||||
|
||||
return hidden_states
|
||||
if isinstance(hidden_states, tuple):
|
||||
per_token_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
else:
|
||||
# dynamic quant
|
||||
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states
|
||||
)
|
||||
|
||||
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
|
||||
hidden_states.device
|
||||
)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
scale=[self.w13_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[per_token_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight],
|
||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _forward_ll(dispatch_output: DeepEPLLOutput):
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(dispatch_output, DeepEPLLOutput)
|
||||
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
|
||||
|
||||
per_token_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
group_list = group_list.to(torch.int64)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32,
|
||||
)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||
activation_scale=per_token_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight],
|
||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||
return _forward_normal(dispatch_output)
|
||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||
return _forward_ll(dispatch_output)
|
||||
else:
|
||||
raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
|
||||
|
||||
|
||||
def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
||||
|
||||
@@ -9,7 +9,6 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher.deepep import (
|
||||
AscendDeepEPLLOutput,
|
||||
DeepEPConfig,
|
||||
DeepEPDispatcher,
|
||||
DeepEPLLCombineInput,
|
||||
@@ -23,7 +22,6 @@ from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AscendDeepEPLLOutput",
|
||||
"BaseDispatcher",
|
||||
"BaseDispatcherConfig",
|
||||
"CombineInput",
|
||||
|
||||
@@ -8,7 +8,6 @@ import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
AscendDeepEPLLOutput,
|
||||
DeepEPLLCombineInput,
|
||||
DeepEPLLOutput,
|
||||
DeepEPNormalCombineInput,
|
||||
@@ -47,19 +46,12 @@ class DispatchOutputChecker:
|
||||
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
|
||||
return dispatch_output.format.is_deepep()
|
||||
|
||||
@staticmethod
|
||||
def format_is_ascent_ll(
|
||||
dispatch_output: DispatchOutput,
|
||||
) -> TypeGuard[AscendDeepEPLLOutput]:
|
||||
return dispatch_output.format.is_ascent_ll()
|
||||
|
||||
|
||||
class DispatchOutputFormat(Enum):
|
||||
|
||||
STANDARD = "standard"
|
||||
DEEPEP_NORMAL = "deepep_normal"
|
||||
DEEPEP_LL = "deepep_ll"
|
||||
ASCENT_LL = "ascent_ll"
|
||||
|
||||
def is_standard(self) -> bool:
|
||||
return self == DispatchOutputFormat.STANDARD
|
||||
@@ -76,9 +68,6 @@ class DispatchOutputFormat(Enum):
|
||||
DispatchOutputFormat.DEEPEP_LL,
|
||||
]
|
||||
|
||||
def is_ascent_ll(self) -> bool:
|
||||
return self == DispatchOutputFormat.ASCENT_LL
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class DispatchOutput(Protocol):
|
||||
|
||||
@@ -77,24 +77,8 @@ class DeepEPLLOutput(NamedTuple):
|
||||
return DispatchOutputFormat.DEEPEP_LL
|
||||
|
||||
|
||||
class AscendDeepEPLLOutput(NamedTuple):
|
||||
"""AscendDeepEP low latency dispatch output."""
|
||||
|
||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
|
||||
topk_idx: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
masked_m: torch.Tensor
|
||||
seg_indptr: torch.Tensor
|
||||
expected_m: int
|
||||
|
||||
@property
|
||||
def format(self) -> DispatchOutputFormat:
|
||||
return DispatchOutputFormat.ASCENT_LL
|
||||
|
||||
|
||||
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
||||
assert isinstance(DeepEPLLOutput, DispatchOutput)
|
||||
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
|
||||
|
||||
|
||||
class DeepEPNormalCombineInput(NamedTuple):
|
||||
@@ -434,12 +418,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
):
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
deepep_post_reorder_triton_kernel,
|
||||
)
|
||||
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
||||
output = hidden_states
|
||||
else:
|
||||
if hidden_states.shape[0] > 0:
|
||||
@@ -553,23 +536,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
masked_m
|
||||
)
|
||||
|
||||
if _is_npu:
|
||||
deepep_output = AscendDeepEPLLOutput(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
self.handle[1],
|
||||
expected_m,
|
||||
)
|
||||
else:
|
||||
deepep_output = DeepEPLLOutput(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
deepep_output = DeepEPLLOutput(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
return deepep_output
|
||||
|
||||
def _dispatch_core(
|
||||
|
||||
@@ -330,6 +330,14 @@ class TopK(CustomOp):
|
||||
)
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
|
||||
if expert_location_dispatch_info is not None:
|
||||
topk_ids = topk_ids_logical_to_physical(
|
||||
topk_ids, expert_location_dispatch_info
|
||||
)
|
||||
get_global_expert_distribution_recorder().on_select_experts(
|
||||
topk_ids=topk_ids
|
||||
)
|
||||
|
||||
return StandardTopKOutput(topk_weights, topk_ids, _)
|
||||
else:
|
||||
self.topk_config.torch_native = True
|
||||
|
||||
Reference in New Issue
Block a user