Add mxfp8 MoE quantization (#6670)
### What this PR does / why we need it?
Support mxfp8 quantization for Qwen MoE models.
An adaptor is used to make the hardware-specific behavior clearer and more
maintainable.
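
For context, mxfp8 is the OCP microscaling FP8 format: FP8 element values that share one power-of-two scale (stored as an E8M0 exponent) per block of 32 consecutive elements. The sketch below is illustrative only, not the kernel added by this PR:

```python
import torch

BLOCK = 32        # OCP MX: 32 consecutive elements share one scale
FP8_MAX = 448.0   # largest finite float8_e4m3fn value
E4M3_EMAX = 8     # floor(log2(FP8_MAX))

def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Toy per-block mxfp8 quantization; assumes x.numel() is a multiple of 32."""
    blocks = x.float().reshape(-1, BLOCK)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=torch.finfo(torch.float32).tiny)
    # Shared scale is a pure power of two (an E8M0 exponent in real mxfp8),
    # chosen so the block maximum lands inside the FP8 representable range.
    scale = torch.exp2(torch.floor(torch.log2(amax)) - E4M3_EMAX)
    q = (blocks / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale  # dequantize with: q.float() * scale

# x_q, x_scale = mxfp8_quantize(torch.randn(4, 64))
```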
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main: 13397841ab
---------
Signed-off-by: fangrongcan <17343701736@163.com>
Signed-off-by: wangyao-i <iwangyao@outlook.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Eric-dot <60131170+Eric-dot@users.noreply.github.com>
Co-authored-by: fangrongcan <f00876277@china.huawei.com>
Co-authored-by: wangyao-i <iwangyao@outlook.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
```diff
@@ -38,6 +38,7 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithMC2,
 )
 from vllm_ascend.quantization.methods.base import QuantType
+from vllm_ascend.quantization.quant_parser import parse_mxfp_quant_params

 _MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}

```
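`parse_mxfp_quant_params` is unpacked into five values in the dispatch hunk below. Its implementation is not part of this diff; a minimal sketch consistent with that call site, assuming it simply reads optional keys out of the forwarded kwargs, could look like:

```python
# Hypothetical sketch of the parser contract implied by the call site below;
# the real vllm_ascend.quantization.quant_parser may differ.
def parse_mxfp_quant_params(**kwargs):
    """Pull the mxfp-related knobs out of the catch-all kwargs."""
    return (
        kwargs.get("act_quant_type"),        # activation dtype, e.g. an FP8 QuantType
        kwargs.get("weight_quant_type"),     # weight dtype
        kwargs.get("scale_type"),            # dtype of the shared block scales
        kwargs.get("per_token_scale_type"),  # dtype of optional per-token scales
        kwargs.get("round_mode"),            # rounding used when casting to FP8
    )
```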
```diff
@@ -129,6 +130,7 @@ class MoECommMethod(ABC):
         dynamic_eplb: bool = False,
         mc2_mask: torch.Tensor = None,
         pertoken_scale: torch.Tensor | None = None,
+        **kwargs,
     ):
         # Check constraints
         assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
```
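With the catch-all `**kwargs` in place, a quantization method can thread mode-specific options through without further signature churn. A hypothetical set of forwarded kwargs (only the keys come from this diff; the values are placeholders):

```python
# Hypothetical kwargs an mxfp8 quantization method might forward through **kwargs.
mxfp_kwargs = dict(
    use_mxfp_quant=True,          # selects the mxfp dispatch/MLP path
    act_quant_type=None,          # a QuantType member in real use (exact name not shown here)
    weight_quant_type=None,
    scale_type=None,
    per_token_scale_type=None,
    round_mode=None,              # rounding mode for the FP8 cast
    comm_quant_mode=None,         # consumed only on the MC2 dispatcher path
    rollback_quant_config=None,   # optional fallback to unquantized compute
)
```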
```diff
@@ -140,20 +142,36 @@ class MoECommMethod(ABC):
         # Apply log2phy if needed
         if log2phy is not None:
             topk_ids = log2phy[topk_ids]

-        dispatch_results = self.token_dispatcher.token_dispatch(
-            hidden_states=hidden_states,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            expert_map=expert_map,
-            global_redundant_expert_num=self.moe_config.global_redundant_expert_num,
-            mc2_mask=mc2_mask,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            with_quant=use_int8_w8a8 or use_int4_w4a8,
-            dynamic_eplb=dynamic_eplb,
-            pertoken_scale=pertoken_scale,
-        )
+        # TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced
+        # by different quantization modes will be consolidated into a dataclass in a follow-up.
+        use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
+        dispatch_with_quant = use_int8_w8a8 or use_int4_w4a8 or use_mxfp_quant
+        act_quant_type, weight_quant_type, scale_type, per_token_scale_type, round_mode = parse_mxfp_quant_params(
+            **kwargs
+        )
+
+        dispatch_kwargs = {
+            "hidden_states": hidden_states,
+            "topk_weights": topk_weights,
+            "topk_ids": topk_ids,
+            "expert_map": expert_map,
+            "global_redundant_expert_num": self.moe_config.global_redundant_expert_num,
+            "mc2_mask": mc2_mask,
+            "apply_router_weight_on_input": apply_router_weight_on_input,
+            "dynamic_eplb": dynamic_eplb,
+            "pertoken_scale": pertoken_scale,
+        }
+
+        if isinstance(self.token_dispatcher, TokenDispatcherWithMC2):
+            dispatch_kwargs["with_quant"] = dispatch_with_quant
+            dispatch_kwargs["comm_quant_mode"] = kwargs.get("comm_quant_mode")
+            dispatch_kwargs["y_dtype"] = act_quant_type if use_mxfp_quant else None
+            dispatch_kwargs["use_mxfp_quant"] = use_mxfp_quant
+        else:
+            dispatch_kwargs["with_quant"] = use_int8_w8a8 or use_int4_w4a8
+
+        dispatch_results = self.token_dispatcher.token_dispatch(**dispatch_kwargs)

         mlp_output = unified_apply_mlp(
             hidden_states=dispatch_results.hidden_states,
             w1=w1,
```
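The TODO in the hunk above proposes folding the growing argument list into a dataclass. One possible shape for that follow-up (the class and field layout are hypothetical, not from this PR):

```python
from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class MxfpQuantParams:
    """Hypothetical container for the knobs this diff currently threads through kwargs."""
    use_mxfp_quant: bool = False
    act_quant_type: Any = None        # QuantType in the real code
    weight_quant_type: Any = None
    scale_type: Any = None
    per_token_scale_type: Any = None
    round_mode: Any = None
    pertoken_scale: torch.Tensor | None = None
```

`token_dispatch` and `unified_apply_mlp` could then accept a single `quant_params` argument in place of the per-key `kwargs.get()` calls.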
```diff
@@ -171,10 +189,18 @@ class MoECommMethod(ABC):
             w1_offset=w1_offset,
             w2_offset=w2_offset,
             topk_scales=dispatch_results.topk_scales,
-            with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16,
-            fusion=use_int8_w8a8 and self.use_fusion_ops,
+            with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16 or use_mxfp_quant,
+            fusion=(use_int8_w8a8 or use_mxfp_quant) and self.use_fusion_ops,
             need_trans=need_trans,
             dynamic_eplb=dynamic_eplb,
+            use_mxfp_quant=use_mxfp_quant,
+            act_quant_type=act_quant_type,
+            weight_quant_type=weight_quant_type,
+            scale_type=scale_type,
+            per_token_scale_type=per_token_scale_type,
+            round_mode=round_mode,
+            use_bf16=(hidden_states.dtype == torch.bfloat16),
+            rollback_quant_config=kwargs.get("rollback_quant_config"),
         )

         before_combine_evt = torch.npu.current_stream().record_event()
```
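The two replaced flags are the behavioral core of this hunk: `with_quant` now also triggers for mxfp, and mxfp reuses the existing fused-kernel gate rather than adding a new one. A quick standalone sanity table mirroring the expressions above (with `use_fusion_ops` assumed `True`):

```python
# Which paths the updated flags select for each quantization mode; not library code.
for mode in ["none", "int8_w8a8", "int4_w4a8", "int4_w4a16", "mxfp"]:
    use_int8_w8a8 = mode == "int8_w8a8"
    use_int4_w4a8 = mode == "int4_w4a8"
    use_int4_w4a16 = mode == "int4_w4a16"
    use_mxfp_quant = mode == "mxfp"
    with_quant = use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16 or use_mxfp_quant
    fusion = use_int8_w8a8 or use_mxfp_quant  # "and self.use_fusion_ops" in the real code
    print(f"{mode:>10}: with_quant={with_quant}, fusion={fusion}")
```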
```diff
@@ -317,6 +343,7 @@ class FusedMC2CommImpl(MoECommMethod):
         dynamic_eplb: bool = False,
         mc2_mask: torch.Tensor = None,
         pertoken_scale: torch.Tensor | None = None,
+        **kwargs,
     ):
         assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."

```