add mxfp8 moe quantization (#6670)
### What this PR does / why we need it?
Support MXFP8 quantization for Qwen MoE models.
An adaptor is used to make the hardware-specific behavior clearer and more
maintainable.
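
As a rough illustration of the adaptor idea only (a simplified sketch, not the actual `DeviceOperator` from `vllm_ascend.device.device_op`; every name below other than `npu_moe_init_routing` is hypothetical), the entry point stays device-agnostic while hardware-specific kernels are registered behind it:

```python
from typing import Any, Callable

# Hypothetical registry keyed by device type; the real DeviceOperator may be
# structured differently.
_IMPLS: dict[str, Callable[..., Any]] = {}


def register_impl(device_type: str):
    """Register a hardware-specific implementation under a device-type key."""
    def deco(fn: Callable[..., Any]) -> Callable[..., Any]:
        _IMPLS[device_type] = fn
        return fn
    return deco


class DeviceOperatorSketch:
    """Callers use one device-agnostic entry point; the adaptor selects the
    hardware-specific kernel for the detected device type."""

    current_device_type = "a3"  # in real code this would come from get_ascend_device_type()

    @classmethod
    def npu_moe_init_routing(cls, *args: Any, **kwargs: Any) -> Any:
        return _IMPLS[cls.current_device_type](*args, **kwargs)


@register_impl("a3")
def _routing_a3(*args: Any, **kwargs: Any) -> Any:
    # On real hardware this would call the A3-specific routing custom op.
    return ("a3", args, kwargs)


@register_impl("a5")
def _routing_a5(*args: Any, **kwargs: Any) -> Any:
    # The A5 path would call a kernel that understands MXFP8 activations.
    return ("a5", args, kwargs)


print(DeviceOperatorSketch.npu_moe_init_routing("hidden_states", quant_mode=4))
```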
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main: 13397841ab
---------
Signed-off-by: fangrongcan <17343701736@163.com>
Signed-off-by: wangyao-i <iwangyao@outlook.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Eric-dot <60131170+Eric-dot@users.noreply.github.com>
Co-authored-by: fangrongcan <f00876277@china.huawei.com>
Co-authored-by: wangyao-i <iwangyao@outlook.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
@@ -28,6 +28,7 @@ import torch_npu
 from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import get_ep_group
 
+from vllm_ascend.device.device_op import DeviceOperator
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
 from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
@@ -103,8 +104,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         self.ep_rank_id = get_mc2_group().rank_in_group
         self.ep_world_size = get_mc2_group().world_size
         self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
-        self.need_extra_args = get_ascend_device_type() == AscendDeviceType.A3
+        self.need_extra_args = get_ascend_device_type() in [AscendDeviceType.A3, AscendDeviceType.A5]
+        self.a5_need_extra_args = get_ascend_device_type() == AscendDeviceType.A5
         # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
         # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
         # improve communication performance.
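
For reference, a minimal way to apply the A2 tuning mentioned in the NOTE above (the variable names are taken from the comment; in practice they usually need to be exported before the HCCL communicator is created, e.g. in the launch environment):

```python
import os

# A2 only, per the NOTE above: prefer intra-node PCIe and disable intra-node
# RoCE to cut cross-machine communication traffic. setdefault() keeps any
# value the user has already exported.
os.environ.setdefault("HCCL_INTRA_PCIE_ENABLE", "1")
os.environ.setdefault("HCCL_INTRA_ROCE_ENABLE", "0")
```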
@@ -136,8 +137,21 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         expert_map: torch.Tensor,
         mc2_mask: torch.Tensor,
         global_redundant_expert_num: int = 0,
+        **kwargs,
     ):
-        quant_mode = 2 if self.with_quant else 0
+        use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
+        comm_quant_mode = kwargs.get("comm_quant_mode")
+        # NOTE: quant_mode differs by quant feature:
+        # - Legacy int communication quantization uses quant_mode=2.
+        # - A5 MXFP8 communication uses quant_mode=4.
+        # TODO(linfeng): The quantization-related parameters need to be consolidated into a single
+        # dataclass, and the FP8 MoE code path should be integrated into it going forward.
+        if comm_quant_mode is not None:
+            quant_mode = comm_quant_mode
+        elif self.with_quant:
+            quant_mode = 4 if self.a5_need_extra_args and use_mxfp_quant else 2
+        else:
+            quant_mode = 0
         self.moe_expert_num = len(expert_map) + global_redundant_expert_num
         kwargs_mc2 = {
             "x": hidden_states,
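
One possible shape for the consolidation the TODO above asks for, shown only as a sketch (the class and field names are made up; the 0/2/4 values mirror the branch just added):

```python
from dataclasses import dataclass


@dataclass
class MoECommQuantConfig:
    """Hypothetical container for the dispatch quantization knobs used above."""

    with_quant: bool = False
    use_mxfp_quant: bool = False
    is_a5: bool = False
    comm_quant_mode: int | None = None  # explicit override, if the caller sets one

    def resolve_quant_mode(self) -> int:
        # Mirrors the branch in get_dispatch_mc2_kwargs:
        # explicit override > MXFP8 on A5 (4) > legacy int quant (2) > no quant (0).
        if self.comm_quant_mode is not None:
            return self.comm_quant_mode
        if self.with_quant:
            return 4 if self.is_a5 and self.use_mxfp_quant else 2
        return 0


# Example: A5 with MXFP8 communication quantization resolves to mode 4.
assert MoECommQuantConfig(with_quant=True, use_mxfp_quant=True, is_a5=True).resolve_quant_mode() == 4
```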
@@ -164,7 +178,12 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
                     "tp_rank_id": 0,
                 }
             )
-        if self.need_expert_scale:
+        if self.a5_need_extra_args and use_mxfp_quant:
+            y_dtype = kwargs.get("y_dtype")
+            if self.with_quant:
+                y_dtype = torch.float8_e4m3fn if y_dtype is None else y_dtype
+            stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype})
+        if self.need_expert_scale or self.a5_need_extra_args:
             stage1_kwargs.update(
                 {
                     "expert_scales": topk_weights.to(torch.float32),
@@ -186,11 +205,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
         with_quant: bool = False,
         dynamic_eplb: bool = False,
         pertoken_scale: torch.Tensor | None = None,
+        **kwargs,
     ):
         self.with_quant = with_quant
 
         kwargs_mc2 = self.get_dispatch_mc2_kwargs(
-            hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num
+            hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num, **kwargs
         )
         output = (
             torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
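
On the caller side, the new `**kwargs` pass-through means the MXFP8 request can be expressed as plain keyword arguments. A sketch of a helper that builds them (the helper itself is hypothetical; `use_mxfp_quant` and `y_dtype` are the keys read by `get_dispatch_mc2_kwargs` above):

```python
import torch


def build_mxfp8_dispatch_kwargs(use_mxfp_quant: bool, y_dtype: torch.dtype | None = None) -> dict:
    """Hypothetical helper: extra keyword arguments a caller could pass through
    token_dispatch(**extra) to request the A5 MXFP8 communication path.
    If with_quant is set and no y_dtype is supplied, the dispatcher itself
    defaults the output dtype to torch.float8_e4m3fn (see the hunk above)."""
    extra: dict = {"use_mxfp_quant": use_mxfp_quant}
    if y_dtype is not None:
        extra["y_dtype"] = y_dtype
    return extra


# e.g. dispatcher.token_dispatch(..., **build_mxfp8_dispatch_kwargs(True))
print(build_mxfp8_dispatch_kwargs(True, torch.float8_e4m3fn))
```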
@@ -337,19 +356,16 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
         first_expert_idx = 0
         last_expert_idx = self.num_experts_local
         global_num_experts = self.num_experts_local
 
-        sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = (
-            torch.ops._C_ascend.npu_moe_init_routing_custom(
-                hidden_states,
-                topk_ids,
-                scale=pertoken_scale,
-                active_num=num_tokens * self.top_k,
-                expert_num=global_num_experts,
-                expert_tokens_num_type=1,
-                expert_tokens_num_flag=True,
-                active_expert_range=[first_expert_idx, last_expert_idx],
-                quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
-            )
-        )
+        sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = DeviceOperator.npu_moe_init_routing(
+            hidden_states,
+            topk_ids,
+            scale=pertoken_scale,
+            active_num=num_tokens * self.top_k,
+            expert_num=global_num_experts,
+            expert_tokens_num_type=1,
+            expert_tokens_num_flag=True,
+            active_expert_range=[first_expert_idx, last_expert_idx],
+            quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
+        )
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 1  # `count` mode
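
To make the `count` mode comment concrete, a toy example of what `expert_tokens` holds per local expert after the `.to(torch.int64)` cast above; the cumulative form is shown only for contrast and is not produced by this code path:

```python
import torch

# Toy numbers for 4 local experts; not the output of the real routing kernel.
expert_tokens_count = torch.tensor([3, 0, 5, 2], dtype=torch.int64)  # "count" mode: tokens per expert

# A cumulative-offset view of the same routing, for comparison only.
expert_tokens_cumsum = torch.cumsum(expert_tokens_count, dim=0)      # tensor([ 3,  3,  8, 10])

print(expert_tokens_count, expert_tokens_cumsum)
```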