diff --git a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py new file mode 100644 index 0000000..cd94101 --- /dev/null +++ b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py @@ -0,0 +1,176 @@ +import copy +from unittest.mock import Mock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import ( + TorchairAscendW4A8DynamicFusedMoEMethod, + TorchairAscendW4A8DynamicLinearMethod) + + +class TestAscendW4A8DynamicLinearMethod(TestBase): + + def setUp(self): + self.method = TorchairAscendW4A8DynamicLinearMethod() + self.method.group_size = 8 + + def test_get_weight(self): + weight = self.method.get_weight(8, 32, torch.bfloat16) + self.assertEqual(weight["weight"].dtype, torch.int8) + self.assertEqual(weight["weight"].shape, (32, 8)) + + def test_get_pergroup_param(self): + params = self.method.get_pergroup_param(8, 32, torch.bfloat16) + self.assertEqual(params["weight_scale"].dtype, torch.bfloat16) + self.assertEqual(params["weight_scale"].shape, (32, 1)) + self.assertEqual(params["weight_offset"].dtype, torch.bfloat16) + self.assertEqual(params["weight_offset"].shape, (32, 1)) + self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16) + self.assertEqual(params["weight_scale_second"].shape, (32, 1)) + self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16) + self.assertEqual(params["weight_offset_second"].shape, (32, 1)) + + +class TestAscendW4A8DynamicFusedMoEMethod(TestBase): + experts = 8 + input_size = 16 + output_size = 56 + group_size = 2 + + @patch( + 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config' + ) + @patch( + 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group') + @patch("vllm_ascend.ascend_config.get_ascend_config") + @patch( + 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group' + ) + 
@patch('torch.distributed.get_rank', return_value=0) + def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config, + mock_get_ep_group, get_current_vllm_config): + mock_ascend_config = Mock() + mock_ascend_config.torchair_graph_config = Mock(enabled=False) + mock_get_ascend_config.return_value = mock_ascend_config + mock_vllm_config = Mock() + mock_vllm_config.quant_config = Mock(quant_description={ + "group_size": self.group_size, + "version": "0.0.0" + }) + mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True) + get_current_vllm_config.return_value = mock_vllm_config + self.quant_method = TorchairAscendW4A8DynamicFusedMoEMethod() + + def test_get_weight(self): + # old quant version w4a8 weight + param_dict = self.quant_method.get_weight(self.experts, + self.input_size, + self.output_size, + torch.bfloat16) + self.assertEqual(param_dict["w13_weight"].dtype, torch.int8) + self.assertEqual(param_dict["w13_weight"].shape, + (self.experts, 2 * self.input_size, self.output_size)) + # new quant version weight + self.quant_method.new_quant_version = True + param_dict = self.quant_method.get_weight(self.experts, + self.input_size, + self.output_size, + torch.bfloat16) + self.assertEqual(param_dict["w13_weight"].dtype, torch.int8) + self.assertEqual(param_dict["w13_weight"].shape, + (self.experts, self.input_size, self.output_size)) + + def test_get_dynamic_quant_param(self): + # old quant version weight + param_dict = self.quant_method.get_dynamic_quant_param( + self.experts, self.input_size, self.output_size, torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale"].shape, + (self.experts, 2 * self.input_size, 1)) + self.assertEqual(param_dict["w13_weight_scale_second"].dtype, + torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale_second"].shape, + (self.experts, 2 * self.input_size, + self.output_size // self.group_size)) + 
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w2_weight_scale"].shape, + (self.experts, self.output_size, 1)) + self.assertEqual(param_dict["w2_weight_scale_second"].dtype, + torch.bfloat16) + self.assertEqual(param_dict["w2_weight_scale_second"].shape, + (self.experts, self.output_size, + self.input_size // self.group_size)) + # new quant version weight + self.quant_method.new_quant_version = True + param_dict = self.quant_method.get_dynamic_quant_param( + self.experts, self.input_size, self.output_size, torch.bfloat16) + self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32) + self.assertEqual( + param_dict["w2_scale_bias"].shape, + (self.experts, self.output_size, 16 // self.quant_method.tp_size)) + + @patch('torch_npu.npu_quantize') + @patch('torch.Tensor.npu') + def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize): + # old quant version weight + layer = torch.nn.Module() + layer.w13_weight = torch.nn.Parameter(torch.zeros( + (self.experts, 2 * self.input_size, self.output_size), + dtype=torch.int8), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.output_size, self.input_size), + dtype=torch.int8), + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16), + requires_grad=False) + layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones( + (self.experts, 2 * self.input_size, + self.output_size // self.group_size), + dtype=torch.bfloat16), + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(torch.ones( + (self.experts, self.output_size, 1), dtype=torch.bfloat16), + requires_grad=False) + layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones( + (self.experts, self.output_size, + self.input_size // self.group_size), + dtype=torch.bfloat16), + requires_grad=False) + new_layer = copy.deepcopy(layer) + + 
mock_npu.return_value = torch.Tensor() + mock_npu_quantize.return_value = torch.Tensor() + self.quant_method.process_weights_after_loading(layer) + self.assertTrue(hasattr(layer, "w13_scale_bias")) + self.assertEqual(layer.w13_scale_bias.data.shape, + (self.experts, 2 * self.input_size)) + self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32) + self.assertTrue(hasattr(layer, "w2_scale_bias")) + self.assertEqual(layer.w2_scale_bias.data.shape, + (self.experts, self.output_size)) + self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32) + # new quant version weight + self.quant_method.new_quant_version = True + new_layer.w13_weight.data = torch.zeros( + (self.experts, self.input_size, self.output_size), + dtype=torch.int8) + new_layer.w2_weight.data = torch.zeros( + (self.experts, self.output_size // 2, self.input_size), + dtype=torch.int8) + w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1), + dtype=torch.float32) + new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, + requires_grad=False) + w2_scale_bias = torch.zeros( + (self.experts, self.output_size, 16 // self.quant_method.tp_size), + dtype=torch.float32) + new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias, + requires_grad=False) + self.quant_method.process_weights_after_loading(new_layer) + self.assertEqual(new_layer.w13_scale_bias.data.shape, + (self.experts, 2 * self.input_size)) + self.assertEqual(new_layer.w2_scale_bias.data.shape, + (self.experts, self.output_size)) diff --git a/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py b/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py new file mode 100644 index 0000000..520155d --- /dev/null +++ b/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py @@ -0,0 +1,75 @@ +from unittest.mock import MagicMock, patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ + torchair_fused_experts_with_all2all + + 
+class TestAscendW8A8FusedMoEMethod(TestBase): + + def setUp(self): + self.hidden_size = 128 + self.num_tokens = 128 + self.placeholder = torch.randn(self.num_tokens, + self.hidden_size, + dtype=torch.bfloat16) + + @patch("torch.distributed.all_to_all_single") + @patch("torch_npu.npu_moe_re_routing") + @patch("torch_npu.npu_grouped_matmul") + @patch("torch_npu.npu_swiglu") + @patch("torch_npu.npu_dynamic_quant") + @patch("torch_npu.npu_moe_finalize_routing") + @patch("torch_npu.npu_moe_init_routing") + def test_torchair_fused_experts_with_all2all( + self, mock_moe_init_routing, mock_moe_finalize_routing, + mock_dynamic_quant, mock_swiglu, mock_grouped_matmul, + mock_moe_re_routing, mock_all_to_all_single): + + expert_map = MagicMock() + ep_group = MagicMock() + placeholder_int8 = torch.randint(0, + 100, + (self.num_tokens, self.hidden_size), + dtype=torch.int8) + placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32) + mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_( + input) + mock_moe_init_routing.return_value = ( + placeholder_int8, + placeholder_ones, + placeholder_ones, + ) + mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder, + torch.randint(0, + 100, + (self.num_tokens, ), + dtype=torch.int32), + self.placeholder) + mock_grouped_matmul.return_value = self.placeholder + mock_swiglu.return_value = self.placeholder + mock_dynamic_quant.return_value = ( + placeholder_int8, + torch.randn(self.num_tokens), + ) + mock_moe_finalize_routing.return_value = self.placeholder + + result = torchair_fused_experts_with_all2all( + hidden_states=self.placeholder, + w1=self.placeholder, + w1_scale=self.placeholder, + w2=self.placeholder, + w2_scale=self.placeholder, + topk_weights=self.placeholder, + topk_ids=self.placeholder, + top_k=8, + expert_map=expert_map, + ep_group=ep_group, + log2phy=None, + global_redundant_expert_num=256, + ) + self.assertIsNotNone(result) + self.assertEqual(result.dtype, 
torch.bfloat16) + self.assertEqual(result.shape, (128, 128)) diff --git a/tests/ut/torchair/test_utils.py b/tests/ut/torchair/test_utils.py index 45416bd..ab85d28 100644 --- a/tests/ut/torchair/test_utils.py +++ b/tests/ut/torchair/test_utils.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch import torch from tests.ut.base import TestBase +from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE from vllm_ascend.torchair import utils @@ -120,3 +121,15 @@ class TestTorchairUtils(TestBase): utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ) mock_npu_cast.assert_not_called() + + def test_torchair_quant_method_register(self): + + TorchairW8A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[ + "W8A8_DYNAMIC"] + TorchairW4A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[ + "W4A8_DYNAMIC"] + utils.torchair_quant_method_register() + self.assertNotEqual(TorchairW8A8DYNAMICQuantizer, + SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"]) + self.assertNotEqual(TorchairW4A8DYNAMICQuantizer, + SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"]) diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index aa3e923..b31549d 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -71,8 +71,9 @@ from vllm.sequence import IntermediateTensors from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.quantization.quant_config import AscendLinearMethod -from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE +from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ + TorchairAscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor, npu_prefetch @@ -261,8 +262,9 @@ class TorchairDeepseekV2MLP(nn.Module): quant_method = self.gate_up_proj.quant_method if 
isinstance(quant_method, UnquantizedLinearMethod): self.act_fn = TorchairDeepseekV2SiluAndMul() - elif (isinstance(quant_method, AscendLinearMethod) and isinstance( - quant_method.quant_method, AscendW8A8DynamicLinearMethod)): + elif (isinstance(quant_method, AscendLinearMethod) + and isinstance(quant_method.quant_method, + TorchairAscendW8A8DynamicLinearMethod)): # TODO(sdmyzlp): Currently preserved as before: # 1. The only quantization supported for silu is W8A8Dynamic # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16 diff --git a/vllm_ascend/torchair/quantization/__init__.py b/vllm_ascend/torchair/quantization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/torchair/quantization/torchair_quantizer.py b/vllm_ascend/torchair/quantization/torchair_quantizer.py new file mode 100644 index 0000000..1d1d584 --- /dev/null +++ b/vllm_ascend/torchair/quantization/torchair_quantizer.py @@ -0,0 +1,29 @@ +from vllm_ascend.quantization.quantizer import VLLMAscendQuantizer +from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import ( + TorchairAscendW4A8DynamicFusedMoEMethod, + TorchairAscendW4A8DynamicLinearMethod) +from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( + TorchairAscendW8A8DynamicFusedMoEMethod, + TorchairAscendW8A8DynamicLinearMethod) + + +class TorchairW8A8DYNAMICQuantizer(VLLMAscendQuantizer): + + @staticmethod + def build_linear_method(): + return TorchairAscendW8A8DynamicLinearMethod() + + @staticmethod + def build_moe_method(): + return TorchairAscendW8A8DynamicFusedMoEMethod() + + +class TorchairW4A8DYNAMICQuantizer(VLLMAscendQuantizer): + + @staticmethod + def build_linear_method(): + return TorchairAscendW4A8DynamicLinearMethod() + + @staticmethod + def build_moe_method(): + return TorchairAscendW4A8DynamicFusedMoEMethod() diff --git a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py new file mode 100644 index 
0000000..0354b47 --- /dev/null +++ b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py @@ -0,0 +1,424 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Callable, Dict, Optional + +import numpy as np +import torch +import torch_npu +from vllm.config import get_current_vllm_config +from vllm.distributed import get_ep_group +from vllm.forward_context import get_forward_context + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.distributed.parallel_state import get_mc2_group +from vllm_ascend.ops.layers.experts_selector import select_experts +from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( + torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2) +from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor + + +class TorchairAscendW4A8DynamicLinearMethod: + """Linear method for Ascend W4A8_DYNAMIC + """ + + def __init__(self): + self.transpose_weight = True + try: + self.group_size = get_current_vllm_config( + ).quant_config.quant_description.get("group_size", 256) + except AttributeError: + self.group_size = 256 + + @staticmethod + def get_weight(input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = { + "weight": 
torch.empty(output_size, input_size, dtype=torch.int8) + } + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def get_perchannel_param(output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_scale_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + params_dict["weight_offset_second"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=params_dtype) + return params_dict + + @staticmethod + def process_scale_second(weight: torch.Tensor, scale: torch.Tensor, + per_group_scale: torch.Tensor): + k, n = weight.shape + group_num, n = per_group_scale.shape + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + antiquant_scale = (scale * per_group_scale).reshape(group_num, n) + return antiquant_scale.npu(), bias + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = None, + ) -> torch.Tensor: + return torch_npu.npu_weight_quant_batchmatmul( + x, + layer.weight, + antiquant_scale=layer.weight_scale_second.to(x.dtype), + antiquant_group_size=self.group_size, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module): + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight_scale.data = layer.weight_scale.data.flatten().to( + torch.float32) + layer.weight_offset.data = 
layer.weight_offset.data.flatten() + layer.weight_scale_second.data, scale_bias = self.process_scale_second( + layer.weight.data, + layer.weight_scale.data, + layer.weight_scale_second.data.transpose(0, 1).contiguous(), + ) + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32)) + + +class TorchairAscendW4A8DynamicFusedMoEMethod: + """FusedMoe method for Ascend W4A8_DYNAMIC. + """ + + def __init__(self): + self.transpose_weight = True + + self.ep_group = get_ep_group() + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 256) + quant_version = vllm_config.quant_config.quant_description.get( + "version", "0") + # NOTE: new quantize weights: 2 int4 pack into int8 + self.new_quant_version = quant_version == "1.0.0" + self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size + if self.new_quant_version and self.tp_size > 16: + raise ValueError( + "The current weight does not support moe part tp>16.") + + try: + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) + except AttributeError: + self.moe_all_to_all_group_name = "" + + def get_weight(self, num_experts: int, + intermediate_size_per_partition: int, hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + if self.new_quant_version: + w13_output_size = intermediate_size_per_partition + w2_output_size = hidden_sizes // 2 + else: + w13_output_size = 2 * 
intermediate_size_per_partition + w2_output_size = hidden_sizes + + param_dict["w13_weight"] = torch.empty(num_experts, + w13_output_size, + hidden_sizes, + dtype=torch.int8) + param_dict["w2_weight"] = torch.empty(num_experts, + w2_output_size, + intermediate_size_per_partition, + dtype=torch.int8) + return param_dict + + def get_dynamic_quant_param(self, num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + + param_dict["w13_weight_offset"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + + param_dict["w13_weight_scale_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=params_dtype) + + param_dict["w13_weight_offset_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=params_dtype) + + param_dict["w2_weight_scale"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=params_dtype) + param_dict["w2_weight_offset"] = torch.empty(num_experts, + hidden_sizes, + 1, + dtype=params_dtype) + param_dict["w2_weight_scale_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=params_dtype) + param_dict["w2_weight_offset_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=params_dtype) + + if self.new_quant_version: + param_dict["w13_scale_bias"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32) + param_dict["w2_scale_bias"] = torch.empty(num_experts, + hidden_sizes, + 16 // self.tp_size, + dtype=torch.float32) + + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, 
+ router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = True, + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[ + 1] == global_num_experts, "Number of global experts mismatch" + + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts) + + fused_moe_state = get_forward_context().fused_moe_state + shared_gate_up, shared_dequant_scale = None, None + if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(quantized_x_for_share, router_logits) + share_up_out, _ = shared_experts.gate_up_proj( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_gate_up, shared_dequant_scale = share_up_out[ + 0], share_up_out[1] + + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. 
+ if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + + topk_weights = topk_weights.to(x.dtype) + if fused_moe_state == FusedMoEState.MC2: + return torchair_fused_experts_with_mc2( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale_second, + w2_scale=layer.w2_weight_scale_second, + w1_scale_bias=layer.w13_scale_bias, + w2_scale_bias=layer.w2_scale_bias, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + moe_all_to_all_group_name=self.moe_all_to_all_group_name, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + is_torchair=self.torchair_graph_enabled, + quantized_x_for_share=shared_gate_up, + dynamic_scale_for_share=shared_dequant_scale, + mc2_mask=kwargs.get("mc2_mask", None)) + else: + # The current implementation of deepseek moe splits hidden_states + # according to tp_size before they are feed into layers module. + # Therefore, all2all is needed no matter how dp/tp is set so as to + # dispatch/combine tokens. 
+ return torchair_fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale_second, + w2_scale=layer.w2_weight_scale_second, + w1_scale_bias=layer.w13_scale_bias, + w2_scale_bias=layer.w2_scale_bias, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=self.ep_group, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + ) + + def process_scale(self, weight: torch.Tensor, scale, per_group_scale): + group_num, k, n = weight.shape + # the weight of the new version is reduced by half by pack n, so it needs to be restored + if self.new_quant_version: + n = n * 2 + per_group_scale = per_group_scale.reshape(group_num, -1, n) + group_num, quantgroup_num, n = per_group_scale.shape + bias = None + if not self.new_quant_version: + weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ + per_group_scale.reshape([group_num, quantgroup_num, 1, n]) + weight_high = weight_high.reshape([group_num, k, n]) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1) + scale_fp32 = (scale * per_group_scale).to(torch.float16).to( + torch.float32) + scale_fp32_np = scale_fp32.cpu().numpy() + scale_fp32_np.dtype = np.uint32 + sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), + dtype=np.uint32) + + sscale_uint64[..., ::2] = scale_fp32_np + + sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), + dtype=np.int64).copy() + sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape( + group_num, quantgroup_num, n) + sscale_uint64_tensor = sscale_uint64_tensor.npu() + return sscale_uint64_tensor, bias + + def update_bias(self, layer, w13_bias, w2_bias): + if self.new_quant_version: + layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose( + 1, 2).contiguous().sum(axis=1) + layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose( + 1, 2).contiguous().sum(axis=1) + else: + 
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) + layer.register_parameter("w13_scale_bias", w13_scale_bias) + w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False) + layer.register_parameter("w2_scale_bias", w2_scale_bias) + + def pack_to_int32(self, weight: torch.Tensor): + if self.new_quant_version: + group_num, k, n = weight.shape + assert n % 4 == 0, "the last dim of weight needs to be divided by 4" + packed_n = n // 4 + # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 + packed_weight = torch.from_numpy( + np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32)) + return packed_weight.reshape(group_num, k, packed_n).npu() + else: + return torch_npu.npu_quantize(weight.to(torch.float32), + torch.tensor([1.]).npu(), None, + torch.quint4x2, -1, False) + + def process_weights_after_loading(self, layer): + if self.transpose_weight: + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose( + 1, 2).contiguous() + layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( + 1, 2).contiguous() + layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( + 1, 2).contiguous() + layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose( + 1, 2).contiguous() + layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose( + 1, 2).contiguous() + + layer.w13_weight_scale_second.data, w13_bias = self.process_scale( + layer.w13_weight, layer.w13_weight_scale.data, + layer.w13_weight_scale_second.data) + layer.w2_weight_scale_second.data, w2_bias = self.process_scale( + layer.w2_weight, layer.w2_weight_scale.data, + layer.w2_weight_scale_second.data) + + self.update_bias(layer, w13_bias, w2_bias) + + layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) + layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) diff --git 
a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py new file mode 100644 index 0000000..9de9cc7 --- /dev/null +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -0,0 +1,1016 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch_npu +from vllm.distributed import GroupCoordinator, get_ep_group +from vllm.forward_context import get_forward_context + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.distributed.parallel_state import get_mc2_group +from vllm_ascend.ops.layers.experts_selector import select_experts +from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, + dispose_tensor, get_ascend_soc_version) + + +def torchair_apply_mlp_decode(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> 
down_proj + Args: + hidden_states: input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + w2_scale: weights2 scale with shape (num_experts, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). + transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) + Returns: + hidden_states: output hidden states after MLP. + """ + + if dynamic_scale is None: + unquantized_hidden_states = hidden_states + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + # Dispose the original unquantized hidden states + # to save npu memory because they're no longer used.
+ dispose_tensor(unquantized_hidden_states) + else: + pertoken_scale = dynamic_scale + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32)[0] + + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale, + activation_scale=pertoken_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=group_list, + activate_left=True, + quant_mode=1, + ) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=w2_scale.dtype)[0] + return hidden_states + + +def torchair_apply_mlp(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj + + Args: + hidden_states: input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + w2_scale: weights2 scale with shape (num_experts, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). 
+ transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) + + Returns: + hidden_states: output hidden states after MLP. + """ + + if dynamic_scale is None: + unquantized_hidden_states = hidden_states + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + # Dispose the original unquantized hidden states + # to save npu memory because they're no longer used. + dispose_tensor(unquantized_hidden_states) + else: + pertoken_scale = dynamic_scale + + bias1, bias2 = None, None + _output_dtype = w2_scale.dtype + + if w1_scale_bias is not None: + if group_list_type == 0: + group_list = torch.cat( + [group_list[:1], torch.diff(group_list, dim=0)]) + group_list_type = 1 + bias1 = [w1_scale_bias] + bias2 = [w2_scale_bias] + # TODO w4a8 scene: dynamic acquisition of dtype in the future + _output_dtype = torch.bfloat16 + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + scale=[w1_scale], + bias=bias1, + per_token_scale=[pertoken_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( + hidden_states) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + bias=bias2, + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + + return hidden_states + + +def torchair_fused_experts_with_mc2( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: 
    torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    expert_map: torch.Tensor = None,
    moe_all_to_all_group_name: str = "",
    log2phy: torch.Tensor = None,
    global_redundant_expert_num: int = 0,
    shared_experts: Optional[Any] = None,
    is_torchair: bool = False,
    quantized_x_for_share: Optional[Any] = None,
    dynamic_scale_for_share: Optional[Any] = None,
    mc2_mask: Optional[torch.Tensor] = None,
    shared_gate_up: Optional[Any] = None,
    shared_dequant_scale: Optional[Any] = None,
    w1_scale_bias: torch.Tensor = None,
    w2_scale_bias: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Fused MoE using the MC2 dispatch/combine kernels: dispatch tokens to
    # expert ranks, run the quantized expert MLP, combine weighted results.
    assert mc2_mask is not None
    # Map logical expert ids to physical ids when redundant experts exist.
    if log2phy is not None:
        topk_ids = log2phy[topk_ids]

    quant_mode = 2
    ep_group = get_mc2_group()
    ep_rank_id = ep_group.rank_in_group
    ep_world_size = ep_group.world_size

    # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
    need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                       or is_torchair)

    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3

    enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")

    if (expert_map is not None):
        moe_expert_num = len(expert_map) + global_redundant_expert_num
    else:
        moe_expert_num = global_redundant_expert_num
    # hidden_states = hidden_states.bfloat16()
    kwargs_mc2 = {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": 0,
    }

    stage1_kwargs = {
        "scales": None,
        "quant_mode": quant_mode,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": ep_world_size,
        "ep_rank_id": ep_rank_id,
    }
    if need_extra_args:
        stage1_kwargs.update({
            "group_tp": moe_all_to_all_group_name,
            "tp_world_size": 1,
            "tp_rank_id": 0,
        })
    if a3_need_extra_args and enable_dispatch_v2:
        stage1_kwargs.update({
            "x_active_mask": mc2_mask,
        })
    kwargs_mc2.update(stage1_kwargs)

    output = torch_npu.npu_moe_distribute_dispatch_v2(
        **kwargs_mc2
    ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
        **kwargs_mc2)
    # comm_stream.wait_stream(torch.npu.current_stream())
    expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
        0:5]

    # Overlap the shared-expert activation on a secondary stream while the
    # routed experts are computed below.
    if shared_experts is not None:
        with npu_stream_switch("moe_secondary", 0):
            npu_wait_tensor(shared_gate_up, expand_x)
            shared_act_out = shared_experts.act_fn(
                (shared_gate_up, shared_dequant_scale))
            shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]

    # `expand_x` will be disposed in the `apply_mlp` function
    if w1_scale_bias is None:
        down_out_list = torchair_apply_mlp_decode(expand_x,
                                                  w1,
                                                  w1_scale,
                                                  w2,
                                                  w2_scale,
                                                  expert_token_nums,
                                                  dynamic_scale=dynamic_scale)
    else:
        # w4a8 scene, cannot use apply_mlp_decode because the operator is not supported
        down_out_list = torchair_apply_mlp(expand_x,
                                           w1,
                                           w1_scale,
                                           w2,
                                           w2_scale,
                                           expert_token_nums,
                                           dynamic_scale=dynamic_scale,
                                           w1_scale_bias=w1_scale_bias,
                                           w2_scale_bias=w2_scale_bias)

    # moeCombine
    kwargs_mc2 = {
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expert_scales": topk_weights.to(torch.float32),
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": 0,
    }
    tp_recv_counts = torch.empty(1,
                                 dtype=torch.int32,
                                 device=hidden_states.device)
    stage3_kwargs = {
        "ep_send_counts": ep_recv_counts,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": ep_world_size,
        "ep_rank_id": ep_rank_id,
    }
    if enable_dispatch_v2:
        stage3_kwargs.update({
            "assist_info_for_combine":
            assist_info_for_combine,
        })
    else:
        stage3_kwargs.update({
            "expand_idx": assist_info_for_combine,
        })
    if need_extra_args:
        stage3_kwargs.update({
            "tp_send_counts": tp_recv_counts,
            "group_tp": moe_all_to_all_group_name,
            "tp_world_size": 1,
            "tp_rank_id": 0,
        })
    if a3_need_extra_args and enable_dispatch_v2:
        stage3_kwargs.update({
            "x_active_mask": mc2_mask,
        })
    kwargs_mc2.update(stage3_kwargs)

    hidden_states = torch_npu.npu_moe_distribute_combine_v2(
        **kwargs_mc2
    ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
        **kwargs_mc2)

    if shared_experts is None:
        return hidden_states
    else:
        # Finish the shared-expert down projection once the routed-expert
        # outputs it depends on are ready.
        with npu_stream_switch("moe_secondary", 0):
            npu_wait_tensor(shared_act, down_out_list)
            shared_output, _ = shared_experts.down_proj(
                (shared_act, swiglu_out_scale))
        return hidden_states, shared_output


def torchair_init_routing_quant(hidden_states, top_k, topk_ids,
                                global_num_experts):
    """Fallback routing + quantization used when the fused
    `npu_moe_init_routing_quant` operator is unavailable: expand tokens per
    selected expert, count tokens per global expert, then dynamically
    quantize the expanded tokens."""
    num_tokens, _ = hidden_states.shape
    row_idx_len = num_tokens * top_k
    row_idx = (torch.arange(0,
                            row_idx_len,
                            dtype=torch.int32,
                            device=hidden_states.device).view(
                                top_k, -1).permute(1, 0).contiguous())
    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
        hidden_states,
        row_idx=row_idx,
        expert_idx=topk_ids,
        active_num=num_tokens)

    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
        1, 0).contiguous().view(-1))
    global_expert_tokens = torch.bincount(expanded_expert_idx,
                                          minlength=global_num_experts)
    global_expert_tokens = global_expert_tokens.to(torch.int32)
    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
    return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales


# currently expert parallelism implemented with all2all
# is under-optimized.
def torchair_fused_experts_with_all2all(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    expert_map: torch.Tensor = None,
    ep_group: GroupCoordinator = None,
    log2phy: torch.Tensor = None,
    global_redundant_expert_num: int = 0,
    w1_scale_bias: torch.Tensor = None,
    w2_scale_bias: torch.Tensor = None,
):
    """Fused MoE where cross-rank token exchange uses all2all collectives:
    quantize+route, all2all dispatch, expert MLP, all2all gather, finalize."""
    if log2phy is not None:
        topk_ids = log2phy[topk_ids]
    original_shape = hidden_states.shape
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    num_tokens, _ = hidden_states.shape
    num_experts = w1.shape[0]

    if expert_map is not None:
        global_num_experts = len(expert_map) + global_redundant_expert_num
        # Prefer the fused routing+quant operator when the installed
        # torch_npu provides it; otherwise fall back to the composed version.
        if hasattr(torch_npu, "npu_moe_init_routing_quant"):
            quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
                hidden_states,
                expert_idx=topk_ids.to(torch.int32),
                active_num=0,
                expert_capacity=0,
                expert_num=global_num_experts,
                drop_pad_mode=0,
                expert_tokens_num_mode=2,
                expert_tokens_before_capacity_flag=False,
                quant_mode=1,
            )
        else:
            quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = torchair_init_routing_quant(
                hidden_states, top_k, topk_ids, global_num_experts)

        # Exchange per-expert token counts so every rank knows how much data
        # it will send and receive in the data all2all below.
        gather_sizes = global_expert_tokens.new_empty(
            global_expert_tokens.shape[0])
        dist.all_to_all_single(gather_sizes, global_expert_tokens)

        token_counts_combined = torch.stack(
            [gather_sizes, global_expert_tokens], dim=0)
        token_counts_combined = token_counts_combined.view(
            2, ep_group.world_size, -1).sum(dim=2)
        # NOTE(review): non_blocking copy followed by an immediate .numpy()
        # read relies on implicit synchronization — confirm upstream.
        token_counts_combined_cpu = token_counts_combined.to(
            torch.device("cpu"), non_blocking=True).numpy()
        all_tokens = gather_sizes.sum()

        gathered_tokens = quantized_tokens.new_empty(all_tokens.item(),
                                                     quantized_tokens.shape[1])
        dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0])
        gather_size_list = token_counts_combined_cpu[1]
        scatter_size_list = token_counts_combined_cpu[0]

        dist.all_to_all_single(gathered_tokens, quantized_tokens,
                               scatter_size_list, gather_size_list)
        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
                               gather_size_list)

        # Re-sort received tokens by local expert and get per-expert counts.
        hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
            gathered_tokens,
            gather_sizes.view(ep_group.world_size, -1),
            per_token_scales=dynamic_scale)
        expert_tokens = expert_tokens.to(torch.int64)
        group_list_type = 1
    else:
        # No expert-parallel map: route locally without any collective.
        row_idx_len = num_tokens * top_k
        row_idx = torch.arange(0,
                               row_idx_len,
                               dtype=torch.int32,
                               device=topk_weights.device).view(
                                   top_k, -1).permute(1, 0).contiguous()
        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
            hidden_states,
            row_idx=row_idx,
            expert_idx=topk_ids,
            active_num=num_tokens)

        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
            expanded_expert_idx, num_experts)
        expert_tokens = expert_tokens.to(torch.int64)
        group_list_type = 0
        dynamic_scale = None

    # `hidden_states` will be disposed in the `apply_mlp` function
    hidden_states = torchair_apply_mlp(
        hidden_states,
        w1,
        w1_scale,  #17
        w2,
        w2_scale,
        expert_tokens,  #16
        dynamic_scale=dynamic_scale,
        group_list_type=group_list_type,
        w1_scale_bias=w1_scale_bias,
        w2_scale_bias=w2_scale_bias)

    if expert_map is not None:
        reordered_outputs = torch.index_select(
            hidden_states,
            dim=0,
            # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU
            index=inverse_indices.to(torch.float32).argsort().to(torch.int32))

        # Return expert outputs to their owning ranks (reverse all2all).
        hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
        dist.all_to_all_single(hidden_states, reordered_outputs,
                               gather_size_list, scatter_size_list)

        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            hidden_states,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights,
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=None,
            drop_pad_mode=2)
    else:
        # TODO: Reorder device memory 2 times here, replace the current
        # implementation here when suitable operators become available.
        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            hidden_states,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights,
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids,
        )
    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)
    return final_hidden_states


def torchair_fused_experts_with_allgather(hidden_states: torch.Tensor,
                                          w1: torch.Tensor,
                                          w1_scale: torch.Tensor,
                                          w2: torch.Tensor,
                                          w2_scale: torch.Tensor,
                                          topk_weights: torch.Tensor,
                                          topk_ids: torch.Tensor,
                                          top_k: int,
                                          expert_map: torch.Tensor = None):
    """Fused MoE for the allgather-EP path: route tokens to this rank's local
    expert range with npu_moe_init_routing_v2 and finish with the fused
    grouped-matmul-finalize-routing kernel."""
    original_shape = hidden_states.shape
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    num_tokens = hidden_states.shape[0]
    batch_size, hidden_size = hidden_states.shape
    topk_weights = topk_weights.to(hidden_states.dtype)

    ep_group = get_ep_group().device_group
    ep_rank = torch.distributed.get_rank(group=ep_group)
    ep_size = torch.distributed.get_world_size(ep_group)

    global_num_experts = len(expert_map)
    local_num_experts = global_num_experts // ep_size

    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)

    # Route only tokens that hit this rank's contiguous expert range.
    hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
        hidden_states,
        topk_ids,
        scale=pertoken_scale,
        offset=None,
        active_num=num_tokens * top_k,
        expert_num=global_num_experts,
        expert_tokens_num_type=1,
        expert_tokens_num_flag=True,
        active_expert_range=[
            ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
        ],
        quant_mode=-1,
        row_idx_type=1)
    group_list_type = 1

    sorted_topk_weight = \
        torch.index_select(topk_weights.view(-1), 0,
                           expanded_x_idx)
    row_index = expanded_x_idx // topk_ids.shape[-1]
    row_index = row_index.to(torch.int64)
    share_input = torch.zeros((batch_size, hidden_size),
                              dtype=torch.bfloat16,
                              device="npu")

    # gmm1: gate_up_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=torch.int32)[0]

    # act_fn: swiglu
    hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
        x=hidden_states,
        weight_scale=w1_scale.to(torch.float32),
        activation_scale=pertoken_scale,
        bias=None,
        quant_scale=None,
        quant_offset=None,
        group_index=expert_tokens,
        activate_left=True,
        quant_mode=1,
    )

    # Fused down_proj + weighted combine back to (batch_size, hidden_size).
    final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing(
        hidden_states,
        w2,
        scale=w2_scale.to(torch.float32),
        bias=None,
        pertoken_scale=pertoken_scale.view(-1),
        group_list=expert_tokens,
        shared_input=share_input,
        logit=sorted_topk_weight.to(torch.float32),
        row_index=row_index,
        output_bs=batch_size).to(torch.bfloat16)

    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)

    return final_hidden_states


def torchair_fused_experts(hidden_states: torch.Tensor,
                           w1: torch.Tensor,
                           w1_scale: torch.Tensor,
                           w2: torch.Tensor,
                           w2_scale: torch.Tensor,
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           top_k: int,
                           expert_map: torch.Tensor = None):
    """Fused MoE without dedicated dispatch kernels: sort tokens by local
    expert id on-device, run the quantized expert MLP, then scatter-add the
    weighted outputs back into the original token order."""
    original_shape = hidden_states.shape
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    num_tokens, _ = hidden_states.shape
    num_experts = w1.shape[0]
    dtype = hidden_states.dtype
    device = hidden_states.device

    if expert_map is not None:
        # Generate token indices and flatten
        token_indices = (torch.arange(num_tokens,
                                      device=device,
                                      dtype=torch.int64).unsqueeze(1).expand(
                                          -1, top_k).reshape(-1))

        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.view(-1)
        experts_flat = topk_ids.view(-1)
        local_experts_flat = expert_map[experts_flat]

        # Filter valid token-expert pairs (expert_map == -1 marks experts
        # owned by other ranks; their weights are zeroed out).
        mask = local_experts_flat != -1
        filtered_weights = torch.where(
            mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
        filtered_experts = torch.where(
            mask, local_experts_flat,
            torch.full_like(local_experts_flat,
                            num_experts)).to(topk_ids.dtype)

        # Sort by local expert IDs
        sort_indices = torch.argsort(filtered_experts)
        sorted_token_indices = token_indices[sort_indices]
        sorted_weights = filtered_weights[sort_indices]

        # Compute token counts with minlength of num_experts
        # This is equivalent to but faster than:
        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
        token_counts = torch.zeros(num_experts + 1,
                                   device=device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
        expert_tokens = token_counts[:num_experts]
        # Rearrange hidden_states
        hidden_states = hidden_states[sorted_token_indices]
        group_list_type = 1
    else:
        row_idx_len = num_tokens * top_k
        row_idx = torch.arange(0,
                               row_idx_len,
                               dtype=torch.int32,
                               device=topk_weights.device).view(
                                   top_k, -1).permute(1, 0).contiguous()
        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
            hidden_states,
            row_idx=row_idx,
            expert_idx=topk_ids,
            active_num=num_tokens)

        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
            expanded_expert_idx, num_experts)
        expert_tokens = expert_tokens.to(torch.int64)
        group_list_type = 0

    # `hidden_states` will be disposed in the `apply_mlp` function
    hidden_states = torchair_apply_mlp(hidden_states,
                                       w1,
                                       w1_scale,
                                       w2,
                                       w2_scale,
                                       expert_tokens,
                                       group_list_type=group_list_type)

    if expert_map is not None:
        # Apply routing weights, zero out rows belonging to remote experts,
        # and accumulate each expert output onto its source token row.
        hidden_states.mul_(sorted_weights.unsqueeze(1))
        final_hidden_states = torch.zeros(*original_shape,
                                          device=device,
                                          dtype=dtype)

        num_valid_tokens = mask.sum()
        valid_token_mask = torch.arange(
            0, sorted_token_indices.shape[0],
            device=device).unsqueeze(1) < num_valid_tokens
        hidden_states = hidden_states.masked_fill_(~valid_token_mask,
                                                   0).to(dtype)
        final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
    else:
        # TODO: Reorder device memory 2 times here, replace the current
        # implementation here when suitable operators become available.
        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            hidden_states,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights,
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids,
        )

    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)
    return final_hidden_states


class TorchairAscendW8A8DynamicLinearMethod:
    """Linear method for Ascend W8A8_DYNAMIC.
+ """ + + def __init__(self): + self.transpose_weight = True + + @staticmethod + def get_weight(input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = { + "weight": torch.empty(output_size, input_size, dtype=torch.int8) + } + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def get_perchannel_param( + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, + 1, + dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, + 1, + dtype=params_dtype) + return params_dict + + def get_pergroup_param(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def apply( + layer: torch.nn.Module, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = 0, + ) -> torch.Tensor: + config = getattr(layer, "_ascend_quant_config", {}) + if not isinstance(x, tuple): + output_dtype = config.get("output_dtype", x.dtype) + quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + else: + assert "output_dtype" in config.keys(), ( + f"DynamicLinearMethod needs explicitly specified `output_dtype`" + f"for pre-quantized input, got config [{config}]") + output_dtype = config["output_dtype"] + quantized_x, dynamic_scale = x + pertoken_scale = (dynamic_scale + if config.get("pertoken_scale", True) else None) + + output = torch_npu.npu_quant_matmul( + quantized_x, + layer.weight, + layer.weight_scale, + pertoken_scale=pertoken_scale, + bias=bias, + output_dtype=output_dtype, + ) + return ((output, dynamic_scale) + if config.get("return_scale", False) else output) + + def process_weights_after_loading(self, layer): + if self.transpose_weight: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + # 
cast quantized weight tensors in NZ format (29) for higher inference speed + layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29) + layer.weight_scale.data = layer.weight_scale.data.flatten() + layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + + +class TorchairAscendW8A8DynamicFusedMoEMethod: + """FusedMoe method for Ascend W8A8_DYNAMIC. + """ + + def __init__(self): + self.transpose_weight = True + + self.ep_group = get_ep_group() + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + try: + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=device_group) + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) + except AttributeError: + self.moe_all_to_all_group_name = "" + + @staticmethod + def get_weight(num_experts: int, intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight"] = torch.empty(num_experts, + 2 * + intermediate_size_per_partition, + hidden_sizes, + dtype=torch.int8) + param_dict["w2_weight"] = torch.empty(num_experts, + hidden_sizes, + intermediate_size_per_partition, + dtype=torch.int8) + return param_dict + + @staticmethod + def get_dynamic_quant_param(num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + param_dict = {} + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + param_dict["w13_weight_offset"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=params_dtype) + param_dict["w2_weight_scale"] = torch.empty(num_experts, + hidden_sizes, + 
                                                    1,
                                                    dtype=params_dtype)
        param_dict["w2_weight_offset"] = torch.empty(num_experts,
                                                     hidden_sizes,
                                                     1,
                                                     dtype=params_dtype)
        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Select experts and dispatch to the fused-experts implementation
        matching the current fused_moe_state (allgather-EP / MC2 /
        allgather / all2all)."""
        assert router_logits.shape[
            1] == global_num_experts, "Number of global experts mismatch"

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        fused_moe_state = get_forward_context().fused_moe_state
        shared_gate_up, shared_dequant_scale = None, None
        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
            # Kick off the shared-expert gate_up projection on a secondary
            # stream so it overlaps with routed-expert computation.
            with npu_stream_switch("moe_secondary", 0):
                npu_wait_tensor(quantized_x_for_share, router_logits)
                share_up_out, _ = shared_experts.gate_up_proj(
                    (quantized_x_for_share, dynamic_scale_for_share))
                shared_gate_up, shared_dequant_scale = share_up_out[
                    0], share_up_out[1]

        # this is a naive implementation for experts load balance so as
        # to avoid accumulating too much tokens on a single rank.
        # currently it is only activated when doing profile runs.
        if enable_force_load_balance:
            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

        topk_weights = topk_weights.to(x.dtype)
        if fused_moe_state == FusedMoEState.AllGatherEP:
            return torchair_fused_experts_with_allgather(
                hidden_states=x,
                w1=layer.w13_weight,
                w1_scale=layer.w13_weight_scale,
                w2=layer.w2_weight,
                w2_scale=layer.w2_weight_scale,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map)
        elif fused_moe_state == FusedMoEState.MC2:
            return torchair_fused_experts_with_mc2(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                w1_scale=layer.w13_weight_scale_fp32,
                w2_scale=layer.w2_weight_scale,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                log2phy=log2phy,
                global_redundant_expert_num=global_redundant_expert_num,
                shared_experts=shared_experts,
                is_torchair=self.torchair_graph_enabled,
                mc2_mask=kwargs.get("mc2_mask", None),
                shared_gate_up=shared_gate_up,
                shared_dequant_scale=shared_dequant_scale)
        elif fused_moe_state in [
                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
        ]:
            return torchair_fused_experts(hidden_states=x,
                                          w1=layer.w13_weight,
                                          w1_scale=layer.w13_weight_scale,
                                          w2=layer.w2_weight,
                                          w2_scale=layer.w2_weight_scale,
                                          topk_weights=topk_weights,
                                          topk_ids=topk_ids,
                                          top_k=top_k,
                                          expert_map=expert_map)
        else:
            # The current implementation of deepseek moe splits hidden_states
            # according to tp_size before they are feed into layers module.
            # Therefore, all2all is needed no matter how dp/tp is set so as to
            # dispatch/combine tokens.
            return torchair_fused_experts_with_all2all(
                hidden_states=x,
                w1=layer.w13_weight,
                w1_scale=layer.w13_weight_scale,
                w2=layer.w2_weight,
                w2_scale=layer.w2_weight_scale,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
                ep_group=self.ep_group,
                log2phy=log2phy,
                global_redundant_expert_num=global_redundant_expert_num,
            )

    def process_weights_after_loading(self, layer):
        """Transpose MoE weights to the grouped-matmul layout and flatten
        the quantization scales/offsets after checkpoint load."""
        if self.transpose_weight:
            layer.w13_weight.data = layer.w13_weight.data.transpose(
                1, 2).contiguous()
            layer.w2_weight.data = layer.w2_weight.data.transpose(
                1, 2).contiguous()
        if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
            torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
            layer.w13_weight_scale.data.shape[0], -1)
        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
            torch.float32)
        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
            layer.w2_weight_scale.data.shape[0], -1)
        layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
            layer.w2_weight_offset.data.shape[0], -1)
diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index c8e0bf1..fb4f583 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -38,6 +38,7 @@ from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
                                         check_torchair_cache_exist,
                                         converting_weight_acl_format,
                                         register_torchair_model,
+                                        torchair_quant_method_register,
                                         write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                is_310p)
@@ -67,6 +68,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
         self._check_batch_sizes_consistency()
 
         register_torchair_model()
+        torchair_quant_method_register()
 
     def
 _get_forward_metadata_across_dp_and_pad(
             self, num_tokens: int, with_prefill: bool, enable_dbo: bool
diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py
index 9d3254f..8448ddc 100644
--- a/vllm_ascend/torchair/utils.py
+++ b/vllm_ascend/torchair/utils.py
@@ -170,3 +170,15 @@ def register_torchair_model():
     ModelRegistry.register_model(
         "Qwen3ForCausalLM",
         "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")
+
+
+def torchair_quant_method_register():
+    # Point the W8A8/W4A8 dynamic entries of the Ascend quantizer registry
+    # at the torchair-specific implementations (imports kept local to avoid
+    # a circular dependency with the quantizer module — TODO confirm).
+    from vllm_ascend.quantization.quantizer import \
+        SUPPORT_ASCEND_QUANTIZER_TYPE
+    from vllm_ascend.torchair.quantization.torchair_quantizer import (
+        TorchairW4A8DYNAMICQuantizer, TorchairW8A8DYNAMICQuantizer)
+
+    SUPPORT_ASCEND_QUANTIZER_TYPE[
+        "W8A8_DYNAMIC"] = TorchairW8A8DYNAMICQuantizer
+    SUPPORT_ASCEND_QUANTIZER_TYPE[
+        "W4A8_DYNAMIC"] = TorchairW4A8DYNAMICQuantizer