[bugfix]Enable dispatch_ffn_combine feature for qwen3.5 (#7066)
### What this PR does / why we need it?
Qwen3.5 Moe supports enabling the dispatch_ffn_combine fusion operator.
Fix problem: In the w8a8 quantization scene, Qwen3.5 model's config.json
lacks the quantize field. The previous logic strictly relied on
quant_type == "w8a8_dynamic" to enable VLLM_ASCEND_ENABLE_FUSED_MC2.
This caused the dispatch_ffn_combine fusion operator to fail to activate
even when the environment variable was set.
Enable dispatch_ffn_combine fusion operator for BF16 scenarios.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: asunxiao <asunxiao@qq.com>
This commit is contained in:
@@ -245,14 +245,18 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
|
|||||||
elif soc_version in {AscendDeviceType.A3}:
|
elif soc_version in {AscendDeviceType.A3}:
|
||||||
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
||||||
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
|
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
|
||||||
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
|
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2
|
||||||
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model)
|
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model)
|
||||||
if num_tokens <= mc2_tokens_capacity:
|
if num_tokens <= mc2_tokens_capacity:
|
||||||
fused_decode_enable = fused_mc2_enable
|
fused_decode_enable = fused_mc2_enable
|
||||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
|
fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
|
||||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||||
fused_decode_enable = fused_mc2_enable and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
|
fused_decode_enable = (
|
||||||
|
fused_mc2_enable
|
||||||
|
and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
|
||||||
|
and quant_type == "w8a8_dynamic"
|
||||||
|
)
|
||||||
moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
|
moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
|
||||||
else:
|
else:
|
||||||
fused_prefill_enable = fused_mc2_enable
|
fused_prefill_enable = fused_mc2_enable
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from functools import wraps
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import torch_npu
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce
|
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
@@ -32,6 +33,7 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import FusedMo
|
|||||||
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner # type: ignore
|
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner # type: ignore
|
||||||
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
|
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
@@ -41,6 +43,7 @@ from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_expe
|
|||||||
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
|
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
|
||||||
from vllm_ascend.quantization.methods.base import QuantType
|
from vllm_ascend.quantization.methods.base import QuantType
|
||||||
from vllm_ascend.utils import (
|
from vllm_ascend.utils import (
|
||||||
|
ACL_FORMAT_FRACTAL_NZ,
|
||||||
enable_sp,
|
enable_sp,
|
||||||
maybe_trans_nz,
|
maybe_trans_nz,
|
||||||
npu_stream_switch,
|
npu_stream_switch,
|
||||||
@@ -77,6 +80,16 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(1, 2).contiguous()
|
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(1, 2).contiguous()
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||||
|
|
||||||
|
# TODO: Current dispatch_ffn_combine fusion operator ONLY supports NZ format.
|
||||||
|
# Therefore, we must cast weights to NZ when fusion is enabled.
|
||||||
|
# Once the underlying dispatch_ffn_combine operator is updated to support
|
||||||
|
# ND format (or other formats), remove this specific 'if' check and the forced
|
||||||
|
# npu_format_cast. At that point, the operator should be able to handle weights
|
||||||
|
# in their native format without explicit casting here.
|
||||||
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2:
|
||||||
|
layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
else:
|
||||||
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
|
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
|
||||||
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
|
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
|
||||||
|
|
||||||
@@ -144,10 +157,33 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
|
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
|
||||||
|
|
||||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||||
|
# NOTE: In the MoECommType.FUSED_MC2 branch, we wrap weights (w1, w2) into lists
|
||||||
|
# and provide dummy scales (w1_scale, w2_scale). This is required because:
|
||||||
|
# The underlying Ascend fused operator (e.g., dispatch_ffn_combine) expects
|
||||||
|
# inputs in a list format.
|
||||||
|
# TODO: Passing an empty tensor as scale for float (BF16) cases is semantically
|
||||||
|
# incorrect. The ideal solution is to pass None. However, if the underlying
|
||||||
|
# dispatch_ffn_combine C++ operator does not support None for the scale argument
|
||||||
|
# (due to signature constraints), we are forced to use a placeholder empty tensor.
|
||||||
|
# This TODO tracks the requirement to update the C++ operator to accept Optional[Tensor]
|
||||||
|
# or None for scales in non-quantized scenarios.
|
||||||
|
if get_forward_context().moe_comm_type == MoECommType.FUSED_MC2:
|
||||||
|
w1 = [layer.w13_weight]
|
||||||
|
w1_scale = [torch.tensor([], dtype=torch.int64)]
|
||||||
|
w2 = [layer.w2_weight]
|
||||||
|
w2_scale = [torch.tensor([], dtype=torch.int64)]
|
||||||
|
else:
|
||||||
|
w1 = layer.w13_weight
|
||||||
|
w1_scale = None
|
||||||
|
w2 = layer.w2_weight
|
||||||
|
w2_scale = None
|
||||||
|
|
||||||
final_hidden_states = moe_comm_method.fused_experts(
|
final_hidden_states = moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=w1,
|
||||||
w2=layer.w2_weight,
|
w2=w2,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
w1_bias=layer.w13_bias if self.moe.has_bias else None,
|
w1_bias=layer.w13_bias if self.moe.has_bias else None,
|
||||||
w2_bias=layer.w2_bias if self.moe.has_bias else None,
|
w2_bias=layer.w2_bias if self.moe.has_bias else None,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
|||||||
Reference in New Issue
Block a user