[4/N][refactor]delete torchair from quantization (#2535)

### What this PR does / why we need it?
After moving the torchair-related quantization code into `torchair_quantization`, this PR removes the remaining torchair logic from the original quantization path.
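
For the non-torchair path, the visible effect is that the MC2 dispatch/combine decision no longer consults a torchair flag; only the Ascend SoC version matters. A minimal sketch of that decision (the enum and helper below are illustrative stand-ins condensed from the diff, not the project's verbatim code):

```python
# Illustrative sketch only: `AscendSocVersion` stands in for
# vllm_ascend.utils.AscendSocVersion, and `need_extra_mc2_args` condenses the
# `need_extra_args` expression changed in this PR.
from enum import Enum


class AscendSocVersion(Enum):
    A2 = "A2"
    A3 = "A3"


def need_extra_mc2_args(soc_version: AscendSocVersion) -> bool:
    # Before: soc_version == AscendSocVersion.A3 or is_torchair
    # After:  the torchair flag is gone; only A3 needs the extra
    #         dispatch & combine parameters.
    return soc_version == AscendSocVersion.A3


assert need_extra_mc2_args(AscendSocVersion.A3)
assert not need_extra_mc2_args(AscendSocVersion.A2)
```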

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: 69244e67e6

Signed-off-by: hust17yixuan <303660421@qq.com>
Author: Wang Yixuan
Date: 2025-08-28 09:10:03 +08:00
Committed by: GitHub
Commit: a955e5d404 (parent c578f817ca)
3 changed files with 16 additions and 42 deletions


@@ -39,14 +39,10 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
-    @patch("vllm_ascend.ascend_config.get_ascend_config")
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
     @patch('torch.distributed.get_rank', return_value=0)
-    def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
-              mock_get_ep_group, get_current_vllm_config):
-        mock_ascend_config = Mock()
-        mock_ascend_config.torchair_graph_config = Mock(enabled=False)
-        mock_get_ascend_config.return_value = mock_ascend_config
+    def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
+              get_current_vllm_config):
         mock_vllm_config = Mock()
         mock_vllm_config.quant_config = Mock(quant_description={
             "group_size": self.group_size,


@@ -24,13 +24,11 @@ from vllm.config import get_current_vllm_config
 from vllm.distributed import get_ep_group
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all,
                                                    fused_experts_with_mc2)
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 
 
 class AscendW4A8DynamicLinearMethod:
@@ -133,9 +131,6 @@ class AscendW4A8DynamicFusedMoEMethod:
         self.ep_group = get_ep_group()
 
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
         vllm_config = get_current_vllm_config()
         self.group_size = vllm_config.quant_config.quant_description.get(
             "group_size", 256)
@@ -284,12 +279,10 @@ class AscendW4A8DynamicFusedMoEMethod:
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
         if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-            with npu_stream_switch("moe_secondary", 0):
-                npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
+            share_up_out, _ = shared_experts.gate_up_proj(
+                (quantized_x_for_share, dynamic_scale_for_share))
+            shared_gate_up, shared_dequant_scale = share_up_out[
+                0], share_up_out[1]
 
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
@@ -315,7 +308,6 @@ class AscendW4A8DynamicFusedMoEMethod:
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
                 quantized_x_for_share=shared_gate_up,
                 dynamic_scale_for_share=shared_dequant_scale,
                 mc2_mask=kwargs.get("mc2_mask", None))


@@ -24,11 +24,9 @@ from vllm.distributed import GroupCoordinator, get_ep_group
 from vllm.forward_context import get_forward_context
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.layers.experts_selector import select_experts
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version)
@@ -213,7 +211,6 @@ def fused_experts_with_mc2(
     log2phy: torch.Tensor = None,
     global_redundant_expert_num: int = 0,
     shared_experts: Optional[Any] = None,
-    is_torchair: bool = False,
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
@@ -232,8 +229,7 @@ def fused_experts_with_mc2(
     ep_world_size = ep_group.world_size
 
     # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
-    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
-                       or is_torchair)
+    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3)
 
     # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
     a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
@@ -282,11 +278,9 @@ def fused_experts_with_mc2(
         0:5]
 
     if shared_experts is not None:
-        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_gate_up, expand_x)
-            shared_act_out = shared_experts.act_fn(
-                (shared_gate_up, shared_dequant_scale))
-            shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
+        shared_act_out = shared_experts.act_fn(
+            (shared_gate_up, shared_dequant_scale))
+        shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
 
     # `expand_x` will be disposed in the `apply_mlp` function
     if w1_scale_bias is None:
@@ -358,10 +352,8 @@ def fused_experts_with_mc2(
 
     if shared_experts is None:
         return hidden_states
     else:
-        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_act, down_out_list)
-            shared_output, _ = shared_experts.down_proj(
-                (shared_act, swiglu_out_scale))
+        shared_output, _ = shared_experts.down_proj(
+            (shared_act, swiglu_out_scale))
         return hidden_states, shared_output
@@ -806,9 +798,6 @@ class AscendW8A8DynamicFusedMoEMethod:
         self.ep_group = get_ep_group()
 
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
         try:
             device_group = get_mc2_group().device_group
             # TODO: Try local_rank = ep_group.rank_in_group
@@ -904,12 +893,10 @@ class AscendW8A8DynamicFusedMoEMethod:
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
         if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-            with npu_stream_switch("moe_secondary", 0):
-                npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
+            share_up_out, _ = shared_experts.gate_up_proj(
+                (quantized_x_for_share, dynamic_scale_for_share))
+            shared_gate_up, shared_dequant_scale = share_up_out[
+                0], share_up_out[1]
 
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
@@ -944,7 +931,6 @@ class AscendW8A8DynamicFusedMoEMethod:
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
                 mc2_mask=kwargs.get("mc2_mask", None),
                 shared_gate_up=shared_gate_up,
                 shared_dequant_scale=shared_dequant_scale)