[main] [bugfix] Fix misjudging quantized/unquantized scenarios (#2627)
### What this PR does / why we need it?
In a mixed-precision scenario, quant_config is not None, but MoE needs
to perform unquantized computation; however, quantized computation is
currently being used. Therefore, we put the with_quant logic into
forward, avoid misjudging in mix-precision scenarios.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.1.1
- vLLM main:
98ac0cb32d
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -543,7 +543,6 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.with_quant = True
|
||||
mock_forward_context.fused_moe_state = FusedMoEState.MC2
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
@@ -587,10 +586,10 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
group_list_type=1,
|
||||
w1_scale_bias=None,
|
||||
w2_scale_bias=None,
|
||||
topk_scales=None)
|
||||
topk_scales=None,
|
||||
with_quant=True)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertTrue(mock_forward_context.with_quant)
|
||||
self.assertEqual(mock_forward_context.fused_moe_state,
|
||||
FusedMoEState.MC2)
|
||||
|
||||
@@ -602,19 +601,15 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
|
||||
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
def test_unified_apply_mlp_without_quantization(
|
||||
self, mock_npu_dynamic_quant, mock_npu_swiglu,
|
||||
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.with_quant = False
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
def test_unified_apply_mlp_without_quantization(self,
|
||||
mock_npu_dynamic_quant,
|
||||
mock_npu_swiglu,
|
||||
mock_npu_grouped_matmul,
|
||||
mock_is_310p):
|
||||
mock_is_310p.return_value = False
|
||||
|
||||
mock_npu_grouped_matmul.side_effect = [[
|
||||
@@ -639,10 +634,8 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
group_list_type=1,
|
||||
w1_scale_bias=None,
|
||||
w2_scale_bias=None,
|
||||
topk_scales=topk_scales)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertFalse(mock_forward_context.with_quant)
|
||||
topk_scales=topk_scales,
|
||||
with_quant=False)
|
||||
|
||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||
mock_npu_swiglu.assert_called_once()
|
||||
@@ -698,10 +691,10 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
group_list_type=1,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
topk_scales=None)
|
||||
topk_scales=None,
|
||||
with_quant=True)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertTrue(mock_forward_context.with_quant)
|
||||
|
||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||
mock_npu_swiglu.assert_called_once()
|
||||
@@ -710,19 +703,13 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
|
||||
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
def test_unified_apply_mlp_without_quantization_310p(
|
||||
self, mock_npu_dynamic_quant, mock_npu_swiglu,
|
||||
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.with_quant = False
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_npu_grouped_matmul, mock_is_310p):
|
||||
mock_is_310p.return_value = True
|
||||
|
||||
mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
|
||||
@@ -750,10 +737,9 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
group_list_type=1,
|
||||
w1_scale_bias=None,
|
||||
w2_scale_bias=None,
|
||||
topk_scales=topk_scales)
|
||||
topk_scales=topk_scales,
|
||||
with_quant=False)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertFalse(mock_forward_context.with_quant)
|
||||
mock_is_310p.assert_called_once()
|
||||
|
||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||
|
||||
@@ -263,7 +263,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
"max_num_tokens": 100,
|
||||
"ep_size": 2,
|
||||
"num_experts": 128,
|
||||
"with_quant": True,
|
||||
}
|
||||
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
|
||||
|
||||
@@ -460,8 +459,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
def test_token_dispatch_with_quant(self):
|
||||
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
||||
num_experts=4,
|
||||
num_local_experts=2,
|
||||
with_quant=True)
|
||||
num_local_experts=2)
|
||||
|
||||
hidden_states = torch.randn(8, 16)
|
||||
topk_weights = torch.rand(8, 4)
|
||||
@@ -476,7 +474,8 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=self.row_idx,
|
||||
expert_map=expert_map)
|
||||
expert_map=expert_map,
|
||||
with_quant=True)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
@@ -486,8 +485,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
def test_token_dispatch_with_quant_no_active_tokens(self):
|
||||
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
||||
num_experts=4,
|
||||
num_local_experts=2,
|
||||
with_quant=True)
|
||||
num_local_experts=2)
|
||||
|
||||
self.mock_repeat_interleave.return_value = torch.tensor(
|
||||
[], dtype=torch.long)
|
||||
@@ -505,7 +503,8 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=self.row_idx,
|
||||
expert_map=expert_map)
|
||||
expert_map=expert_map,
|
||||
with_quant=True)
|
||||
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
|
||||
Reference in New Issue
Block a user