add mxfp8 moe quantization (#6670)
### What this PR does / why we need it?
support mxfp8 quantization (Qwen MOE )
Using adaptor to make the hardware-specific behavior clearer and more
maintainable
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main:
13397841ab
---------
Signed-off-by: fangrongcan <17343701736@163.com>
Signed-off-by: wangyao-i <iwangyao@outlook.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Eric-dot <60131170+Eric-dot@users.noreply.github.com>
Co-authored-by: fangrongcan <f00876277@china.huawei.com>
Co-authored-by: wangyao-i <iwangyao@outlook.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -38,6 +38,7 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
||||
TokenDispatcherWithMC2,
|
||||
)
|
||||
from vllm_ascend.quantization.methods.base import QuantType
|
||||
from vllm_ascend.quantization.quant_parser import parse_mxfp_quant_params
|
||||
|
||||
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
|
||||
|
||||
@@ -129,6 +130,7 @@ class MoECommMethod(ABC):
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Check constraints
|
||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
|
||||
@@ -140,20 +142,36 @@ class MoECommMethod(ABC):
|
||||
# Apply log2phy if needed
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
|
||||
dispatch_results = self.token_dispatcher.token_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
global_redundant_expert_num=self.moe_config.global_redundant_expert_num,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8,
|
||||
dynamic_eplb=dynamic_eplb,
|
||||
pertoken_scale=pertoken_scale,
|
||||
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced
|
||||
# by different quantization modes will be consolidated into a dataclass in a follow-up.
|
||||
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
|
||||
dispatch_with_quant = use_int8_w8a8 or use_int4_w4a8 or use_mxfp_quant
|
||||
act_quant_type, weight_quant_type, scale_type, per_token_scale_type, round_mode = parse_mxfp_quant_params(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
dispatch_kwargs = {
|
||||
"hidden_states": hidden_states,
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"expert_map": expert_map,
|
||||
"global_redundant_expert_num": self.moe_config.global_redundant_expert_num,
|
||||
"mc2_mask": mc2_mask,
|
||||
"apply_router_weight_on_input": apply_router_weight_on_input,
|
||||
"dynamic_eplb": dynamic_eplb,
|
||||
"pertoken_scale": pertoken_scale,
|
||||
}
|
||||
|
||||
if isinstance(self.token_dispatcher, TokenDispatcherWithMC2):
|
||||
dispatch_kwargs["with_quant"] = dispatch_with_quant
|
||||
dispatch_kwargs["comm_quant_mode"] = kwargs.get("comm_quant_mode")
|
||||
dispatch_kwargs["y_dtype"] = act_quant_type if use_mxfp_quant else None
|
||||
dispatch_kwargs["use_mxfp_quant"] = use_mxfp_quant
|
||||
else:
|
||||
dispatch_kwargs["with_quant"] = use_int8_w8a8 or use_int4_w4a8
|
||||
|
||||
dispatch_results = self.token_dispatcher.token_dispatch(**dispatch_kwargs)
|
||||
|
||||
mlp_output = unified_apply_mlp(
|
||||
hidden_states=dispatch_results.hidden_states,
|
||||
w1=w1,
|
||||
@@ -171,10 +189,18 @@ class MoECommMethod(ABC):
|
||||
w1_offset=w1_offset,
|
||||
w2_offset=w2_offset,
|
||||
topk_scales=dispatch_results.topk_scales,
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16,
|
||||
fusion=use_int8_w8a8 and self.use_fusion_ops,
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16 or use_mxfp_quant,
|
||||
fusion=(use_int8_w8a8 or use_mxfp_quant) and self.use_fusion_ops,
|
||||
need_trans=need_trans,
|
||||
dynamic_eplb=dynamic_eplb,
|
||||
use_mxfp_quant=use_mxfp_quant,
|
||||
act_quant_type=act_quant_type,
|
||||
weight_quant_type=weight_quant_type,
|
||||
scale_type=scale_type,
|
||||
per_token_scale_type=per_token_scale_type,
|
||||
round_mode=round_mode,
|
||||
use_bf16=(hidden_states.dtype == torch.bfloat16),
|
||||
rollback_quant_config=kwargs.get("rollback_quant_config"),
|
||||
)
|
||||
|
||||
before_combine_evt = torch.npu.current_stream().record_event()
|
||||
@@ -317,6 +343,7 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||
|
||||
|
||||
@@ -22,7 +22,11 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.device.device_op import DeviceOperator
|
||||
from vllm_ascend.ops.activation import AscendSwigluOAIAndMul
|
||||
from vllm_ascend.quantization.mxfp_compat import (
|
||||
ensure_mxfp8_moe_available,
|
||||
)
|
||||
from vllm_ascend.utils import (
|
||||
dispose_tensor,
|
||||
enable_custom_op,
|
||||
@@ -66,12 +70,22 @@ def cumsum_group_list(
|
||||
)
|
||||
|
||||
|
||||
def _require_single_tensor_for_swiglu_quant(
|
||||
tensor_or_list: list[torch.Tensor] | torch.Tensor, *, name: str
|
||||
) -> torch.Tensor:
|
||||
if isinstance(tensor_or_list, list):
|
||||
if len(tensor_or_list) != 1:
|
||||
raise ValueError(f"{name} must be a tensor or a single-element list, but got {len(tensor_or_list)}.")
|
||||
return tensor_or_list[0]
|
||||
return tensor_or_list
|
||||
|
||||
|
||||
def quant_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: list[torch.Tensor],
|
||||
w1_scale: list[torch.Tensor],
|
||||
w2: list[torch.Tensor],
|
||||
w2_scale: list[torch.Tensor],
|
||||
w1: list[torch.Tensor] | torch.Tensor,
|
||||
w1_scale: list[torch.Tensor] | torch.Tensor,
|
||||
w2: list[torch.Tensor] | torch.Tensor,
|
||||
w2_scale: list[torch.Tensor] | torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
@@ -81,15 +95,45 @@ def quant_apply_mlp(
|
||||
w2_offset: torch.Tensor | None = None,
|
||||
fusion: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different
|
||||
# quantization modes will be consolidated into a dataclass in a follow-up.
|
||||
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
|
||||
act_quant_type = torch.float8_e4m3fn
|
||||
weight_quant_type = None
|
||||
scale_type = None
|
||||
per_token_scale_type = None
|
||||
use_bf16 = True
|
||||
|
||||
input_hidden_dtype = hidden_states.dtype
|
||||
use_gmm_swiglu_quant_fusion = use_mxfp_quant or (fusion and not dynamic_eplb)
|
||||
|
||||
if use_mxfp_quant:
|
||||
act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn)
|
||||
weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn)
|
||||
scale_type = kwargs.get("scale_type")
|
||||
per_token_scale_type = kwargs.get("per_token_scale_type")
|
||||
use_bf16 = kwargs.get("use_bf16", True)
|
||||
|
||||
ensure_mxfp8_moe_available("MXFP MoE MLP path")
|
||||
|
||||
if w1_scale_bias is not None or w2_scale_bias is not None:
|
||||
raise NotImplementedError("MXFP path does not support scale_bias yet.")
|
||||
if w1_offset is not None or w2_offset is not None:
|
||||
raise NotImplementedError("MXFP path does not support antiquant offset yet.")
|
||||
|
||||
if w1_offset is not None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
quantized_hidden_states = None
|
||||
elif dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
hidden_states, pertoken_scale = DeviceOperator.npu_dynamic_quant(
|
||||
hidden_states=hidden_states,
|
||||
dynamic_scale=None,
|
||||
act_quant_type=act_quant_type,
|
||||
use_mxfp_quant=use_mxfp_quant,
|
||||
)
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
quantized_hidden_states = None
|
||||
else:
|
||||
@@ -98,13 +142,14 @@ def quant_apply_mlp(
|
||||
quantized_hidden_states = hidden_states
|
||||
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale[0].dtype
|
||||
_output_dtype = w2_scale[0].dtype if isinstance(w2_scale, list) else w2_scale.dtype
|
||||
|
||||
weight_prefetch_method = get_weight_prefetch_method()
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
|
||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||
if w1_scale_bias is None and w1_offset is None and is_mc2:
|
||||
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
||||
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
|
||||
x=hidden_states,
|
||||
@@ -113,14 +158,16 @@ def quant_apply_mlp(
|
||||
x_scale=pertoken_scale,
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
)
|
||||
elif fusion and not dynamic_eplb:
|
||||
elif use_gmm_swiglu_quant_fusion:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||
hidden_states, swiglu_out_scale, _ = DeviceOperator.npu_grouped_matmul_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight=w1[0],
|
||||
weight=_require_single_tensor_for_swiglu_quant(w1, name="w1"),
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
weight_scale=w1_scale[0],
|
||||
weight_scale=_require_single_tensor_for_swiglu_quant(w1_scale, name="w1_scale"),
|
||||
x_scale=pertoken_scale,
|
||||
bias=None,
|
||||
use_mxfp_quant=use_mxfp_quant,
|
||||
)
|
||||
if quantized_hidden_states is not None:
|
||||
dispose_tensor(quantized_hidden_states)
|
||||
@@ -152,17 +199,23 @@ def quant_apply_mlp(
|
||||
quant_mode=1,
|
||||
)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
hidden_states = DeviceOperator.npu_grouped_matmul_gmm2(
|
||||
hidden_states=hidden_states,
|
||||
weight=w2,
|
||||
scale=w2_scale,
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
weight_scale=w2_scale,
|
||||
per_token_scale=swiglu_out_scale,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale[0].dtype,
|
||||
)[0]
|
||||
group_list_type=group_list_type,
|
||||
input_dtype=input_hidden_dtype,
|
||||
act_quant_type=act_quant_type,
|
||||
weight_quant_type=weight_quant_type,
|
||||
scale_type=scale_type,
|
||||
per_token_scale_type=per_token_scale_type,
|
||||
use_bf16=use_bf16,
|
||||
use_mxfp_quant=use_mxfp_quant,
|
||||
bias=None,
|
||||
fallback_output_dtype=w2_scale[0].dtype if isinstance(w2_scale, list) else w2_scale.dtype,
|
||||
)
|
||||
elif w1_offset is not None:
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
@@ -201,7 +254,7 @@ def quant_apply_mlp(
|
||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||
_output_dtype = torch.bfloat16
|
||||
|
||||
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
||||
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
|
||||
x=hidden_states,
|
||||
@@ -211,15 +264,15 @@ def quant_apply_mlp(
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
bias=bias1,
|
||||
)
|
||||
elif fusion and not dynamic_eplb:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||
elif use_gmm_swiglu_quant_fusion:
|
||||
hidden_states, swiglu_out_scale, _ = DeviceOperator.npu_grouped_matmul_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight=w1[0],
|
||||
bias=bias1,
|
||||
weight=_require_single_tensor_for_swiglu_quant(w1, name="w1"),
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
weight_scale=w1_scale[0],
|
||||
weight_scale=_require_single_tensor_for_swiglu_quant(w1_scale, name="w1_scale"),
|
||||
x_scale=pertoken_scale,
|
||||
bias=bias1,
|
||||
use_mxfp_quant=use_mxfp_quant,
|
||||
)
|
||||
if quantized_hidden_states is not None:
|
||||
dispose_tensor(quantized_hidden_states)
|
||||
@@ -251,18 +304,23 @@ def quant_apply_mlp(
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
hidden_states = DeviceOperator.npu_grouped_matmul_gmm2(
|
||||
hidden_states=hidden_states,
|
||||
weight=w2,
|
||||
scale=w2_scale,
|
||||
bias=bias2,
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
weight_scale=w2_scale,
|
||||
per_token_scale=swiglu_out_scale,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype,
|
||||
)[0]
|
||||
group_list_type=group_list_type,
|
||||
input_dtype=input_hidden_dtype,
|
||||
act_quant_type=act_quant_type,
|
||||
weight_quant_type=weight_quant_type,
|
||||
scale_type=scale_type,
|
||||
per_token_scale_type=per_token_scale_type,
|
||||
use_bf16=use_bf16,
|
||||
use_mxfp_quant=use_mxfp_quant,
|
||||
bias=bias2,
|
||||
fallback_output_dtype=_output_dtype,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -334,26 +392,13 @@ def unified_apply_mlp(
|
||||
fusion: bool = False,
|
||||
need_trans: bool = True,
|
||||
dynamic_eplb: bool = False,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if with_quant:
|
||||
assert w1_scale is not None and w2_scale is not None
|
||||
return quant_apply_mlp(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
w1_offset=w1_offset,
|
||||
w2_offset=w2_offset,
|
||||
fusion=fusion,
|
||||
dynamic_eplb=dynamic_eplb,
|
||||
)
|
||||
else:
|
||||
"""
|
||||
Unified MoE MLP entry.
|
||||
Quant path is dispatched by DeviceOperator with explicit quant-type flags.
|
||||
"""
|
||||
if not with_quant:
|
||||
return unquant_apply_mlp(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
@@ -366,3 +411,34 @@ def unified_apply_mlp(
|
||||
topk_scales=topk_scales,
|
||||
need_trans=need_trans,
|
||||
)
|
||||
|
||||
assert w1_scale is not None and w2_scale is not None
|
||||
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different
|
||||
# quantization modes will be consolidated into a dataclass in a follow-up.
|
||||
act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn)
|
||||
weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn)
|
||||
scale_type = kwargs.get("scale_type")
|
||||
per_token_scale_type = kwargs.get("per_token_scale_type")
|
||||
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
|
||||
return quant_apply_mlp(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
w1_offset=w1_offset,
|
||||
w2_offset=w2_offset,
|
||||
fusion=fusion,
|
||||
dynamic_eplb=dynamic_eplb,
|
||||
act_quant_type=act_quant_type,
|
||||
weight_quant_type=weight_quant_type,
|
||||
scale_type=scale_type,
|
||||
per_token_scale_type=per_token_scale_type,
|
||||
use_mxfp_quant=use_mxfp_quant,
|
||||
use_bf16=kwargs.get("use_bf16", True),
|
||||
)
|
||||
|
||||
@@ -76,7 +76,7 @@ class PrepareAndFinalize(ABC):
|
||||
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
||||
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
||||
replace_allreduce (bool): Bypass default all-reduce behavior
|
||||
quant_type: none, w8a8 or w4a8
|
||||
quant_type: none, w8a8, w4a8 or mxfp8
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
@@ -323,6 +323,10 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
pertoken_scale = None
|
||||
if quant_type == QuantType.W8A8:
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
elif quant_type == QuantType.MXFP8:
|
||||
# TODO(linfeng): MXFP8 with AllGather+EP currently does not pre-quantize
|
||||
# per-token activations in prepare. Keep quantization in the MoE MLP path.
|
||||
pass
|
||||
|
||||
if self.multistream_overlap_gate:
|
||||
assert PrepareAndFinalize.quant_stream is not None
|
||||
|
||||
@@ -28,6 +28,7 @@ import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.device.device_op import DeviceOperator
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
|
||||
@@ -103,8 +104,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.ep_rank_id = get_mc2_group().rank_in_group
|
||||
self.ep_world_size = get_mc2_group().world_size
|
||||
self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||
self.need_extra_args = get_ascend_device_type() == AscendDeviceType.A3
|
||||
|
||||
self.need_extra_args = get_ascend_device_type() in [AscendDeviceType.A3, AscendDeviceType.A5]
|
||||
self.a5_need_extra_args = get_ascend_device_type() == AscendDeviceType.A5
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
@@ -136,8 +137,21 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
expert_map: torch.Tensor,
|
||||
mc2_mask: torch.Tensor,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
quant_mode = 2 if self.with_quant else 0
|
||||
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
|
||||
comm_quant_mode = kwargs.get("comm_quant_mode")
|
||||
# NOTE: quant_mode differs by quant feature:
|
||||
# - Legacy int communication quantization uses quant_mode=2.
|
||||
# - A5 MXFP8 communication uses quant_mode=4.
|
||||
# TODO(linfeng): The quantization-related parameters need to be consolidated into a single
|
||||
# dataclass, and the FP8 MoE code path should be integrated into it going forward.
|
||||
if comm_quant_mode is not None:
|
||||
quant_mode = comm_quant_mode
|
||||
elif self.with_quant:
|
||||
quant_mode = 4 if self.a5_need_extra_args and use_mxfp_quant else 2
|
||||
else:
|
||||
quant_mode = 0
|
||||
self.moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
@@ -164,7 +178,12 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"tp_rank_id": 0,
|
||||
}
|
||||
)
|
||||
if self.need_expert_scale:
|
||||
if self.a5_need_extra_args and use_mxfp_quant:
|
||||
y_dtype = kwargs.get("y_dtype")
|
||||
if self.with_quant:
|
||||
y_dtype = torch.float8_e4m3fn if y_dtype is None else y_dtype
|
||||
stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype})
|
||||
if self.need_expert_scale or self.a5_need_extra_args:
|
||||
stage1_kwargs.update(
|
||||
{
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
@@ -186,11 +205,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.with_quant = with_quant
|
||||
|
||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(
|
||||
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num
|
||||
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num, **kwargs
|
||||
)
|
||||
output = (
|
||||
torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
|
||||
@@ -337,19 +356,16 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
first_expert_idx = 0
|
||||
last_expert_idx = self.num_experts_local
|
||||
global_num_experts = self.num_experts_local
|
||||
|
||||
sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = (
|
||||
torch.ops._C_ascend.npu_moe_init_routing_custom(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
scale=pertoken_scale,
|
||||
active_num=num_tokens * self.top_k,
|
||||
expert_num=global_num_experts,
|
||||
expert_tokens_num_type=1,
|
||||
expert_tokens_num_flag=True,
|
||||
active_expert_range=[first_expert_idx, last_expert_idx],
|
||||
quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
|
||||
)
|
||||
sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = DeviceOperator.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
scale=pertoken_scale,
|
||||
active_num=num_tokens * self.top_k,
|
||||
expert_num=global_num_experts,
|
||||
expert_tokens_num_type=1,
|
||||
expert_tokens_num_flag=True,
|
||||
active_expert_range=[first_expert_idx, last_expert_idx],
|
||||
quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
|
||||
)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 1 # `count` mode
|
||||
|
||||
Reference in New Issue
Block a user