diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 6a51d1d..2e1661b 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -29,7 +29,7 @@ from vllm_ascend.ascend_forward_context import (FusedMoEState, from vllm_ascend.ops.fused_moe import (AscendFusedMoE, AscendUnquantizedFusedMoEMethod) from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp +from vllm_ascend.ops.layers.moe_mlp import cumsum_group_list, unified_apply_mlp from vllm_ascend.utils import AscendSocVersion, adapt_patch adapt_patch(True) @@ -524,6 +524,43 @@ class TestExpertsSelector: assert topk_ids.shape == (8, 2) +class TestCumsumGroupList(TestBase): + + def setUp(self): + self.active_num = 8 + self.expert_num = 128 + self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64) + self.experts[:self.active_num] = 1 + self.experts = self.experts[torch.randperm(self.expert_num)] + self.group_list = self.experts.cumsum(dim=0) + + def test_cumsum_group_list_with_type_0(self): + group_list = self.experts.cumsum(dim=0) + group_list_type = 0 + result = cumsum_group_list(group_list, group_list_type) + self.assertTrue(torch.equal(result, self.group_list)) + + def test_cumsum_group_list_with_type_1(self): + group_list = self.experts + group_list_type = 1 + result = cumsum_group_list(group_list, group_list_type) + self.assertTrue(torch.equal(result, self.group_list)) + + def test_cumsum_group_list_with_type_2(self): + tokens = torch.arange(self.expert_num, dtype=torch.int64) + group_list = torch.cat([ + tokens.reshape(self.expert_num, 1), + self.experts.reshape(self.expert_num, 1) + ], + dim=1) + group_list_type = 2 + result = cumsum_group_list(group_list, + group_list_type, + active_num=self.active_num, + expert_num=self.expert_num) + self.assertTrue(torch.equal(result, self.group_list)) + + class TestUnifiedApplyMLP(TestBase): @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context') @@ -739,3 +776,68 @@ class TestUnifiedApplyMLP(TestBase): self.assertEqual(result.shape, hidden_states.shape) self.assertEqual(result.dtype, torch.float16) + + @patch("vllm_ascend.ops.layers.moe_mlp.get_forward_context") + @patch("torch_npu.npu_grouped_matmul") + @patch("torch_npu.npu_swiglu") + @patch("torch_npu.npu_grouped_matmul_swiglu_quant") + @patch("torch_npu.npu_dynamic_quant") + def test_unified_apply_mlp_with_quantization_and_fusion_mlp( + self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant, + mock_npu_swiglu, mock_npu_grouped_matmul, + mock_get_forward_context): + + mock_forward_context = MagicMock() + mock_forward_context.with_quant = True + mock_forward_context.fused_moe_state = "NOT_MC2" + mock_get_forward_context.return_value = mock_forward_context + + mock_npu_grouped_matmul_swiglu_quant.return_value = (torch.randint( + -128, 127, (10, 40), + dtype=torch.int8), torch.rand( + 10, 1, + dtype=torch.float32), torch.rand(10, 1, dtype=torch.float32)) + mock_npu_grouped_matmul.side_effect = [[ + torch.randn(10, 20, dtype=torch.bfloat16) + ]] + mock_npu_swiglu.return_value = torch.randn(10, + 40, + dtype=torch.bfloat16) + mock_npu_dynamic_quant.return_value = (torch.randint(-128, + 127, (10, 40), + dtype=torch.int8), + torch.rand(10, + 1, + dtype=torch.float32)) + + hidden_states = torch.randn(10, 20, dtype=torch.bfloat16) + w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16) + w1_scale = torch.randn(5, 40, dtype=torch.bfloat16) + w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16) + w2_scale = torch.randn(5, 20, dtype=torch.bfloat16) + w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16) + w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16) + group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) + provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32) + + result = unified_apply_mlp(hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + dynamic_scale=provided_dynamic_scale, + group_list_type=1, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + topk_scales=None, + with_quant=True, + fusion=True) + + mock_get_forward_context.assert_called() + mock_npu_grouped_matmul.assert_called_once() + mock_npu_grouped_matmul_swiglu_quant.assert_called_once() + + self.assertTrue(mock_forward_context.with_quant) + self.assertEqual(result.shape, hidden_states.shape) + self.assertEqual(result.dtype, torch.bfloat16) diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py new file mode 100644 index 0000000..690778e --- /dev/null +++ b/tests/ut/quantization/test_w8a8_dynamic.py @@ -0,0 +1,49 @@ +from unittest.mock import Mock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.quantization.w8a8_dynamic import \ + AscendW8A8DynamicFusedMoEMethod + + +class TestAscendW8A8FusedMoEMethod(TestBase): + num_experts = 8 + hidden_size = 128 + intermediate_size = 128 + + @patch("torch.distributed.get_rank") + @patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group") + @patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config") + @patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group") + def setUp(self, mock_get_ep_group, mock_get_ascend_config, + mock_get_mc2_group, mock_get_rank): + mock_ep_group = Mock() + mock_get_ep_group.return_value = mock_ep_group + mock_ascend_config = Mock() + mock_ascend_config.torchair_graph_config = Mock(enabled=False) + mock_get_ascend_config.return_value = mock_ascend_config + mock_mc2_group = Mock(device_group=0) + mock_get_mc2_group.return_value = mock_mc2_group + mock_rank = Mock() + mock_get_rank.return_value = mock_rank + + self.quant_method = AscendW8A8DynamicFusedMoEMethod() + + def test_get_weight(self): + param_dict = self.quant_method.get_weight(self.num_experts, + self.intermediate_size, + self.hidden_size, + torch.bfloat16) + self.assertEqual(param_dict["w13_weight"].dtype, torch.int8) + self.assertEqual( + param_dict["w13_weight"].shape, + (self.num_experts, 2 * self.intermediate_size, self.hidden_size)) + + def test_get_dynamic_quant_param(self): + param_dict = self.quant_method.get_dynamic_quant_param( + self.num_experts, self.intermediate_size, self.hidden_size, + torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale"].shape, + (self.num_experts, 2 * self.intermediate_size, 1)) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 14396c1..11c4ec5 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -70,7 +70,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor, shared_dequant_scale: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, - with_quant: bool = False): + with_quant: bool = False, + fusion_mlp: bool = False): token_dispatcher = get_forward_context().token_dispatcher results = token_dispatcher.token_dispatch( @@ -100,7 +101,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor, w1_scale_bias=w1_scale_bias, w2_scale_bias=w2_scale_bias, topk_scales=results.get("topk_scales"), - with_quant=with_quant) + with_quant=with_quant, + fusion=fusion_mlp) final_hidden_states = token_dispatcher.token_combine(expert_output) return final_hidden_states diff --git a/vllm_ascend/ops/layers/moe_mlp.py b/vllm_ascend/ops/layers/moe_mlp.py index c73e8ea..d6f67bb 100644 --- a/vllm_ascend/ops/layers/moe_mlp.py +++ b/vllm_ascend/ops/layers/moe_mlp.py @@ -18,22 +18,52 @@ from typing import Optional import torch import torch_npu +from torch.nn.functional import pad from vllm.forward_context import get_forward_context from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.utils import dispose_tensor, is_310p +def cumsum_group_list(group_list: torch.Tensor, + group_list_type: int, + active_num: int = 0, + expert_num: int = 0) -> torch.Tensor: + if group_list_type not in [0, 1, 2]: + raise ValueError( + f"group_list_type should be in [0, 1, 2], but received {group_list_type}" + ) + + if group_list_type == 0: + return group_list + if group_list_type == 1: + return group_list.cumsum(dim=0) + + experts = pad(group_list[:, 0], (1, 0)) + tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0)) + cumsum_group_list = torch.full(size=(expert_num, ), + fill_value=active_num, + dtype=group_list.dtype, + device=group_list.device) + + for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])): + if end > start: + cumsum_group_list[start:end] = tokens[i] + + return cumsum_group_list + + def quant_apply_mlp(hidden_states: torch.Tensor, w1: torch.Tensor, w1_scale: torch.Tensor, w2: torch.Tensor, w2_scale: torch.Tensor, group_list: torch.Tensor, - dynamic_scale: torch.Tensor = None, group_list_type: int = 1, + dynamic_scale: torch.Tensor = None, w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None) -> torch.Tensor: + w2_scale_bias: torch.Tensor = None, + fusion: bool = False) -> torch.Tensor: if dynamic_scale is None: unquantized_hidden_states = hidden_states hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( @@ -49,31 +79,38 @@ def quant_apply_mlp(hidden_states: torch.Tensor, is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2 if w1_scale_bias is None and is_mc2: - w1_scale = w1_scale.to(torch.float32) - - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=3, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=torch.int32)[0] - - # act_fn: swiglu - hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( - x=hidden_states, - weight_scale=w1_scale, - activation_scale=pertoken_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=group_list, - activate_left=True, - quant_mode=1, - ) - + if w1_scale.dtype != torch.float32: + w1_scale = w1_scale.to(torch.float32) + if fusion: + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( + x=hidden_states, + weight=w1, + group_list=cumsum_group_list(group_list, group_list_type), + weight_scale=w1_scale, + x_scale=pertoken_scale) + else: + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32)[0] + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale, + activation_scale=pertoken_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=group_list, + activate_left=True, + quant_mode=1, + ) # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], @@ -92,29 +129,37 @@ def quant_apply_mlp(hidden_states: torch.Tensor, [group_list[:1], torch.diff(group_list, dim=0)]) group_list_type = 1 - bias1 = [w1_scale_bias] + bias1 = [w1_scale_bias] if not fusion else w1_scale_bias bias2 = [w2_scale_bias] # TODO w4a8 scene: dynamic acquisition of dtype in the future _output_dtype = torch.bfloat16 - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - scale=[w1_scale], - bias=bias1, - per_token_scale=[pertoken_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=_output_dtype)[0] - - # act_fn: swiglu - hidden_states = torch_npu.npu_swiglu(hidden_states) - hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( - hidden_states) - + if fusion: + # gmm1: gate_up_proj & act_fn: swiglu + hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( + x=hidden_states, + weight=w1, + bias=bias1, + group_list=cumsum_group_list(group_list, group_list_type), + weight_scale=w1_scale, + x_scale=pertoken_scale) + else: + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + scale=[w1_scale.to(w2_scale.dtype)], + bias=bias1, + per_token_scale=[pertoken_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( + hidden_states) # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], @@ -127,6 +172,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_type=0, group_list=group_list, output_dtype=_output_dtype)[0] + return hidden_states @@ -178,7 +224,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, topk_scales: Optional[torch.Tensor] = None, - with_quant: bool = False) -> torch.Tensor: + with_quant: bool = False, + fusion: bool = False) -> torch.Tensor: if with_quant: return quant_apply_mlp(hidden_states=hidden_states, w1=w1, @@ -189,7 +236,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor, dynamic_scale=dynamic_scale, group_list_type=group_list_type, w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias) + w2_scale_bias=w2_scale_bias, + fusion=fusion) else: return unquant_apply_mlp(hidden_states=hidden_states, w1=w1, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 20c68be..f710bd2 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -31,173 +31,7 @@ from vllm_ascend.ops.common_fused_moe import \ fused_experts as unified_fused_experts from vllm_ascend.ops.fused_moe import unified_fused_experts_eager from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor - - -def apply_mlp_decode(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - group_list: torch.Tensor, - dynamic_scale: torch.Tensor = None, - group_list_type: int = 1) -> torch.Tensor: - """ - apply MLP: gate_up_proj -> swiglu -> down_proj - Args: - hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). - w1: expert weights1 with shape - (num_experts, hidden_size, intermediate_size * 2) - w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) - w2: expert weights2 with shape - (num_experts, intermediate_size, hidden_size) - w2_scale: weights2 scale with shape (num_experts, hidden_size) - group_list: number of tokens for each expert, follow cumsum mode, and - with shape (num_experts). - transpose_weight: - w1: (num_experts, intermediate_size * 2, hidden_size) -> - (num_experts, hidden_size, intermediate_size * 2) - w2: (num_experts, hidden_size, intermediate_size) -> - (num_experts, intermediate_size, hidden_size) - Returns: - hidden_states: output hidden states after MLP. - """ - - if dynamic_scale is None: - unquantized_hidden_states = hidden_states - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( - hidden_states) - # Dispose the original unquantized hidden states - # to save npu memory because they're no longer used. - dispose_tensor(unquantized_hidden_states) - else: - pertoken_scale = dynamic_scale - - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=3, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=torch.int32)[0] - - # act_fn: swiglu - hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( - x=hidden_states, - weight_scale=w1_scale, - activation_scale=pertoken_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=group_list, - activate_left=True, - quant_mode=1, - ) - - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - scale=[w2_scale], - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=w2_scale.dtype)[0] - return hidden_states - - -def apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - group_list: torch.Tensor, - dynamic_scale: torch.Tensor = None, - group_list_type: int = 1, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None) -> torch.Tensor: - """ - apply MLP: gate_up_proj -> swiglu -> down_proj - - Args: - hidden_states: input hidden states with shape (num_tokens, hidden_size). - w1: expert weights1 with shape - (num_experts, hidden_size, intermediate_size * 2) - w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) - w2: expert weights2 with shape - (num_experts, intermediate_size, hidden_size) - w2_scale: weights2 scale with shape (num_experts, hidden_size) - group_list: number of tokens for each expert, follow cumsum mode, and - with shape (num_experts). - transpose_weight: - w1: (num_experts, intermediate_size * 2, hidden_size) -> - (num_experts, hidden_size, intermediate_size * 2) - w2: (num_experts, hidden_size, intermediate_size) -> - (num_experts, intermediate_size, hidden_size) - - Returns: - hidden_states: output hidden states after MLP. - """ - - if dynamic_scale is None: - unquantized_hidden_states = hidden_states - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( - hidden_states) - # Dispose the original unquantized hidden states - # to save npu memory because they're no longer used. - dispose_tensor(unquantized_hidden_states) - else: - pertoken_scale = dynamic_scale - - bias1, bias2 = None, None - _output_dtype = w2_scale.dtype - - if w1_scale_bias is not None: - if group_list_type == 0: - group_list = torch.cat( - [group_list[:1], torch.diff(group_list, dim=0)]) - group_list_type = 1 - bias1 = [w1_scale_bias] - bias2 = [w2_scale_bias] - # TODO w4a8 scene: dynamic acquisition of dtype in the future - _output_dtype = torch.bfloat16 - - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - scale=[w1_scale], - bias=bias1, - per_token_scale=[pertoken_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=_output_dtype)[0] - - # act_fn: swiglu - hidden_states = torch_npu.npu_swiglu(hidden_states) - hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( - hidden_states) - - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - scale=[w2_scale], - bias=bias2, - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=_output_dtype)[0] - - return hidden_states +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ class AscendW8A8DynamicLinearMethod: @@ -418,7 +252,7 @@ class AscendW8A8DynamicFusedMoEMethod: return unified_fused_experts_eager( hidden_states=x, w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, + w1_scale=layer.w13_weight_scale_fp32, w2=layer.w2_weight, w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, @@ -431,7 +265,8 @@ class AscendW8A8DynamicFusedMoEMethod: shared_gate_up=shared_gate_up, shared_dequant_scale=shared_dequant_scale, mc2_mask=kwargs.get("mc2_mask", None), - with_quant=True) + with_quant=True, + fusion_mlp=True) def process_weights_after_loading(self, layer): if self.transpose_weight: @@ -439,6 +274,7 @@ class AscendW8A8DynamicFusedMoEMethod: 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() + torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ) if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP: torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(