@patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
def test_models_distributed_DeepSeek_W4A8DYNAMIC():
    """Smoke-test greedy generation with the W4A8 DeepSeek checkpoint.

    The model is run with tensor parallel size 2, expert parallelism, eager
    mode, the torchair graph disabled and the ascend scheduler enabled.
    ``VLLM_ASCEND_MLA_PA`` is set to ``"1"`` for the duration of the test.
    The test only checks that generation completes; no output is asserted.
    """
    prompts = ["Hello, my name is"]
    max_tokens = 5

    additional_config = {
        "torchair_graph_config": {
            "enabled": False,
        },
        "ascend_scheduler_config": {
            "enabled": True,
        }
    }
    model_path = snapshot_download("vllm-ascend/DeepSeek-R1-w4a8-pruning")

    runner = VllmRunner(
        model_path,
        dtype="auto",
        tensor_parallel_size=2,
        quantization="ascend",
        enforce_eager=True,
        enable_expert_parallel=True,
        additional_config=additional_config,
    )
    with runner as vllm_model:
        vllm_model.generate_greedy(prompts, max_tokens)
class TestW4A8DYNAMICQuantizer(TestBase):
    """Checks that W4A8DYNAMICQuantizer builds its linear and MoE methods."""

    def setUp(self):
        self.quantizer = W4A8DYNAMICQuantizer(quant_description={})

    def test_build_linear_method(self):
        # Replace the method class with a mock so only the factory call
        # itself (no-argument construction) is verified.
        target = ('vllm_ascend.quantization.quantizer.'
                  'AscendW4A8DynamicLinearMethod')
        with patch(target, return_value=MagicMock()) as linear_cls:
            built = self.quantizer.build_linear_method()
            linear_cls.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)

    def test_build_moe_method(self):
        # Same check as above for the fused-MoE factory.
        target = ('vllm_ascend.quantization.quantizer.'
                  'AscendW4A8DynamicFusedMoEMethod')
        with patch(target, return_value=MagicMock()) as moe_cls:
            built = self.quantizer.build_moe_method()
            moe_cls.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
    """Unit tests for ``AscendW4A8DynamicFusedMoEMethod``.

    The distributed-group and config accessors used by the constructor are
    mocked, so no NPU or initialized process group is needed to build the
    object under test.
    """

    # ``w4a8_dynamic`` binds ``get_ascend_config`` with a ``from ... import``,
    # so the name has to be patched in the ``w4a8_dynamic`` namespace (where
    # it is looked up), not in ``vllm_ascend.ascend_config`` (where it is
    # defined). Patching the definition site leaves the constructor calling
    # the real function.
    @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
    @patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
    @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
    @patch('torch.distributed.get_rank', return_value=0)
    def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
              mock_get_ep_group):
        mock_ascend_config = Mock()
        mock_ascend_config.torchair_graph_config = Mock(enabled=False)
        mock_get_ascend_config.return_value = mock_ascend_config
        self.quant_method = AscendW4A8DynamicFusedMoEMethod()

    def test_get_weight(self):
        # 8 experts, intermediate size 4 (so w13 has 2 * 4 rows), hidden 14.
        param_dict = self.quant_method.get_weight(8, 4, 14, torch.bfloat16)
        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
        self.assertEqual(param_dict["w13_weight"].shape, (8, 8, 14))

    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
    def test_get_dynamic_quant_param(self, mock_get_current_vllm_config):
        # group_size=2 gives 14 // 2 = 7 groups for w13 and 4 // 2 = 2 for w2.
        mock_vllm_config = Mock()
        mock_vllm_config.quant_config = Mock(
            quant_description={"group_size": 2})
        mock_get_current_vllm_config.return_value = mock_vllm_config
        param_dict = self.quant_method.get_dynamic_quant_param(
            8, 4, 14, torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale"].shape, (8, 8, 1))
        self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
                         torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale_second"].shape,
                         (8, 8, 7))
        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
        self.assertEqual(param_dict["w2_weight_scale"].shape, (8, 14, 1))
        self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
                         torch.bfloat16)
        self.assertEqual(param_dict["w2_weight_scale_second"].shape,
                         (8, 14, 2))

    @patch('torch_npu.npu_quantize')
    @patch('torch.Tensor.npu')
    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
        # Build a bare layer carrying every parameter that
        # process_weights_after_loading reads, using the same shapes as the
        # two tests above.
        layer = torch.nn.Module()
        layer.w13_weight = torch.nn.Parameter(torch.zeros((8, 8, 14),
                                                          dtype=torch.int8),
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(torch.zeros((8, 14, 4),
                                                         dtype=torch.int8),
                                             requires_grad=False)
        layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
            (8, 8, 1), dtype=torch.bfloat16),
                                                    requires_grad=False)
        layer.w13_weight_offset = torch.nn.Parameter(torch.zeros(
            (8, 8, 1), dtype=torch.bfloat16),
                                                     requires_grad=False)
        layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
            (8, 8, 7), dtype=torch.bfloat16),
                                                           requires_grad=False)
        layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
            (8, 14, 1), dtype=torch.bfloat16),
                                                   requires_grad=False)
        layer.w2_weight_offset = torch.nn.Parameter(torch.zeros(
            (8, 14, 1), dtype=torch.bfloat16),
                                                    requires_grad=False)
        layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
            (8, 14, 2), dtype=torch.bfloat16),
                                                          requires_grad=False)

        # Device transfer and NPU quantization are stubbed out; only the
        # CPU-side bias computation is asserted on below.
        mock_npu.return_value = torch.Tensor()
        mock_npu_quantize.return_value = torch.Tensor()
        self.quant_method.process_weights_after_loading(layer)
        self.assertTrue(hasattr(layer, "w13_scale_bias"))
        self.assertEqual(layer.w13_scale_bias.data.shape, (8, 8))
        self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
        self.assertTrue(hasattr(layer, "w2_scale_bias"))
        self.assertEqual(layer.w2_scale_bias.data.shape, (8, 14))
        self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
in weights: if "rotary_emb.inv_freq" in name: continue + if "module" in name: + continue spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 0b8935b..22c8bc8 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -302,6 +302,9 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): param = torch.nn.Parameter(param_value, requires_grad=False) layer.register_parameter(param_key, param) set_weight_attrs(param, extra_weight_attrs) + if "weight_scale_second" in param_key or "weight_offset_second" in param_key: + setattr(param, "quant_method", + FusedMoeWeightScaleSupported.GROUP.value) def apply( self, @@ -348,4 +351,4 @@ class AscendEmbeddingMethod(AscendLinearMethod): packed_modules_mapping: Dict[str, Any]) -> None: self.quantizer = AscendQuantizer.get_quantizer( quant_config.quant_description, prefix, packed_modules_mapping) - self.quant_method = self.quantizer.build_linear_method() \ No newline at end of file + self.quant_method = self.quantizer.build_linear_method() diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index 90c7512..487597c 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -24,7 +24,8 @@ from vllm.logger import logger from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init, wrapper_vocab_parallel_embedding_init) -from .w4a8_dynamic import AscendW4A8DynamicLinearMethod +from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, + AscendW4A8DynamicLinearMethod) from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, AscendW8A8LinearMethod) from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, @@ -97,12 +98,15 @@ class VLLMAscendQuantizer: if target_function is not None: setattr(original_module, target_function, candidate) - for key, value 
@staticmethod
def build_moe_method():
    """Return a new ``AscendW4A8DynamicFusedMoEMethod`` instance."""
    moe_method = AscendW4A8DynamicFusedMoEMethod()
    return moe_method
class AscendW4A8DynamicFusedMoEMethod:
    """FusedMoe method for Ascend W4A8_DYNAMIC.

    Provides the parameter layouts for expert weights and their scales
    (``get_weight`` / ``get_dynamic_quant_param``), the post-load weight
    processing (``process_weights_after_loading``) and the forward pass
    (``apply``), which dispatches to ``fused_experts_with_mc2`` or
    ``fused_experts_with_all2all``.
    """

    def __init__(self):
        # When set, process_weights_after_loading transposes dims (1, 2) of
        # the expert weights and their first-level scales.
        self.transpose_weight = True

        self.ep_group = get_ep_group()

        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

        try:
            device_group = get_mc2_group().device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            local_rank = torch.distributed.get_rank(group=device_group)
            backend = device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
                local_rank)
        except AttributeError:
            # Best effort: if any object in the chain above lacks the
            # expected attribute, fall back to an empty communicator name.
            self.moe_all_to_all_group_name = ""

    @staticmethod
    def get_weight(num_experts: int, intermediate_size_per_partition: int,
                   hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Allocate uninitialized int8 expert weight tensors.

        Args:
            num_experts: Size of the leading (expert) dimension.
            intermediate_size_per_partition: Intermediate size; ``w13_weight``
                uses twice this value for its second dimension.
            hidden_sizes: Hidden size.
            params_dtype: Accepted for interface compatibility; not used, the
                weights are always ``torch.int8``.

        Returns:
            Dict with ``"w13_weight"`` of shape
            ``(num_experts, 2 * intermediate_size_per_partition, hidden_sizes)``
            and ``"w2_weight"`` of shape
            ``(num_experts, hidden_sizes, intermediate_size_per_partition)``.
        """
        return {
            "w13_weight":
            torch.empty(num_experts,
                        2 * intermediate_size_per_partition,
                        hidden_sizes,
                        dtype=torch.int8),
            "w2_weight":
            torch.empty(num_experts,
                        hidden_sizes,
                        intermediate_size_per_partition,
                        dtype=torch.int8),
        }

    @staticmethod
    def get_dynamic_quant_param(num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        """Allocate uninitialized scale/offset tensors for the expert weights.

        ``group_size`` is read from the current vLLM config's
        ``quant_config.quant_description`` (default 256). The ``*_second``
        tensors have one entry per group along the last dimension, computed
        with floor division.

        Args:
            num_experts: Size of the leading (expert) dimension.
            intermediate_size_per_partition: Intermediate size.
            hidden_sizes: Hidden size.
            params_dtype: dtype of every returned tensor.

        Returns:
            Dict with ``w13_*`` and ``w2_*`` entries for ``weight_scale``,
            ``weight_offset``, ``weight_scale_second`` and
            ``weight_offset_second``.
        """
        config = get_current_vllm_config()
        group_size = config.quant_config.quant_description.get(
            "group_size", 256)

        w13_rows = 2 * intermediate_size_per_partition
        w13_groups = hidden_sizes // group_size
        w2_groups = intermediate_size_per_partition // group_size

        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(num_experts,
                                                     w13_rows,
                                                     1,
                                                     dtype=params_dtype)
        param_dict["w13_weight_offset"] = torch.empty(num_experts,
                                                      w13_rows,
                                                      1,
                                                      dtype=params_dtype)
        param_dict["w13_weight_scale_second"] = torch.empty(
            num_experts, w13_rows, w13_groups, dtype=params_dtype)
        param_dict["w13_weight_offset_second"] = torch.empty(
            num_experts, w13_rows, w13_groups, dtype=params_dtype)

        param_dict["w2_weight_scale"] = torch.empty(num_experts,
                                                    hidden_sizes,
                                                    1,
                                                    dtype=params_dtype)
        param_dict["w2_weight_offset"] = torch.empty(num_experts,
                                                     hidden_sizes,
                                                     1,
                                                     dtype=params_dtype)
        param_dict["w2_weight_scale_second"] = torch.empty(
            num_experts, hidden_sizes, w2_groups, dtype=params_dtype)
        param_dict["w2_weight_offset_second"] = torch.empty(
            num_experts, hidden_sizes, w2_groups, dtype=params_dtype)

        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Route tokens to experts and run the fused expert computation.

        Expert selection uses ``torch_npu.npu_moe_gating_top_k`` when
        ``global_num_experts == 256`` and ``select_experts`` otherwise. The
        expert computation is dispatched on the forward context's
        ``fused_moe_state``: ``fused_experts_with_mc2`` for
        ``FusedMoEState.MC2``, ``fused_experts_with_all2all`` for every other
        state.

        ``enable_force_load_balance`` defaults to True and, when set,
        replaces the selected expert ids with random ones; callers that want
        the router's selection must pass False. ``is_prefill`` is accepted
        but not read here. ``mc2_mask`` is taken from ``kwargs`` for the MC2
        path.

        Returns:
            Whatever the selected ``fused_experts_with_*`` function returns.
        """
        assert router_logits.shape[
            1] == global_num_experts, "Number of global experts mismatch"

        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
        if global_num_experts == 256:
            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                router_logits,
                k=top_k,  # topk currently is 8
                bias=e_score_correction_bias,
                k_group=topk_group,  # fix: 4
                group_count=num_expert_group,  # fix 8
                group_select_mode=
                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
                # out_flag=False, # todo new api; should the third output be output
                # y2_flag=False, # old api; should the third output be output
                routed_scaling_factor=1,
                eps=float(1e-20))
        else:
            topk_weights, topk_ids = select_experts(
                hidden_states=x,
                router_logits=router_logits,
                top_k=top_k,
                use_grouped_topk=use_grouped_topk,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias,
            )

        fused_moe_state = get_forward_context().fused_moe_state
        shared_gate_up, shared_dequant_scale = None, None
        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
            # Run the shared experts' gate_up projection inside the
            # "moe_secondary" stream context.
            with npu_stream_switch("moe_secondary", 0):
                npu_wait_tensor(quantized_x_for_share, router_logits)
                share_up_out, _ = shared_experts.gate_up_proj(
                    (quantized_x_for_share, dynamic_scale_for_share))
                shared_gate_up, shared_dequant_scale = share_up_out[
                    0], share_up_out[1]

        # this is a naive implementation for experts load balance so as
        # to avoid accumulating too much tokens on a single rank.
        # currently it is only activated when doing profile runs.
        if enable_force_load_balance:
            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

        topk_weights = topk_weights.to(x.dtype)
        if fused_moe_state == FusedMoEState.MC2:
            # NOTE(review): the projected shared-expert outputs are passed
            # under the `quantized_x_for_share` / `dynamic_scale_for_share`
            # keywords, while the visible part of fused_experts_with_mc2's
            # signature also has `shared_gate_up` / `shared_dequant_scale`
            # parameters. Confirm these are the intended keywords.
            return fused_experts_with_mc2(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                w1_scale=layer.w13_weight_scale_second,
                w2_scale=layer.w2_weight_scale_second,
                w1_scale_bias=layer.w13_scale_bias,
                w2_scale_bias=layer.w2_scale_bias,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                log2phy=log2phy,
                global_redundant_expert_num=global_redundant_expert_num,
                shared_experts=shared_experts,
                is_torchair=self.torchair_graph_enabled,
                quantized_x_for_share=shared_gate_up,
                dynamic_scale_for_share=shared_dequant_scale,
                mc2_mask=kwargs.get("mc2_mask", None))
        else:
            # The current implementation of deepseek moe splits hidden_states
            # according to tp_size before they are feed into fused_moe module.
            # Therefore, all2all is needed no matter how dp/tp is set so as to
            # dispatch/combine tokens.
            return fused_experts_with_all2all(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                w1_scale=layer.w13_weight_scale_second,
                w2_scale=layer.w2_weight_scale_second,
                w1_scale_bias=layer.w13_scale_bias,
                w2_scale_bias=layer.w2_scale_bias,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
                ep_group=self.ep_group,
                log2phy=log2phy,
                global_redundant_expert_num=global_redundant_expert_num,
            )

    def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
        """Combine the two scale levels and compute the per-column bias.

        Args:
            weight: Tensor of shape ``(group_num, k, n)``.
            scale: First-level scale, multiplied (with broadcasting) against
                both the scaled weight and ``per_group_scale``.
            per_group_scale: Second-level scale, reshaped to
                ``(group_num, quantgroup_num, n)``; ``k`` must be divisible by
                ``quantgroup_num``.

        Returns:
            ``(packed_scale, bias)``. ``packed_scale`` is the result of
            ``.npu()`` on an int64 tensor of shape
            ``(group_num, quantgroup_num, n)`` whose elements carry the
            float32 bit pattern of ``scale * per_group_scale`` (rounded
            through float16) in one 32-bit half, the other half being zero.
            ``bias`` is ``8 * sum over dim 1 of
            (weight * per_group_scale * scale)``.
        """
        group_num, k, n = weight.shape
        per_group_scale = per_group_scale.reshape(group_num, -1, n)
        group_num, quantgroup_num, n = per_group_scale.shape

        # Apply each quant group's scale to the rows belonging to that group.
        weight_high = weight.to(torch.float32).reshape(
            [group_num, quantgroup_num, -1, n]) * per_group_scale.reshape(
                [group_num, quantgroup_num, 1, n])
        weight_high = weight_high.reshape([group_num, k, n])
        # presumably 8 is the int4 zero-point offset -- TODO confirm
        bias = 8 * (weight_high * scale).sum(dim=1)

        scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
            torch.float32)
        # Reinterpret the float32 bits as uint32 with a view; assigning to
        # ndarray.dtype in place is discouraged by NumPy.
        scale_bits = scale_fp32.cpu().numpy().view(np.uint32)

        # Write each 32-bit pattern into the even slot of a uint32 pair, so
        # that every pair reads back as a single 64-bit integer.
        packed = np.zeros((group_num, quantgroup_num, n * 2), dtype=np.uint32)
        packed[..., ::2] = scale_bits

        sscale_uint64_tensor = torch.from_numpy(packed.view(
            np.int64)).reshape(group_num, quantgroup_num, n)
        return sscale_uint64_tensor.npu(), bias

    def process_weights_after_loading(self, layer):
        """Convert loaded expert parameters on ``layer`` in place.

        Steps, in order: optionally transpose dims (1, 2) of the weights and
        first-level scales; flatten the offsets to 2-D; transpose the
        second-level scales; replace the second-level scales with the packed
        scales from ``process_scale`` and register ``w13_scale_bias`` /
        ``w2_scale_bias``; finally re-quantize both weights with
        ``torch_npu.npu_quantize`` using ``torch.quint4x2``.
        """
        if self.transpose_weight:
            layer.w13_weight.data = layer.w13_weight.data.transpose(
                1, 2).contiguous()
            layer.w2_weight.data = layer.w2_weight.data.transpose(
                1, 2).contiguous()
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
            1, 2).contiguous()
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
            1, 2).contiguous()
        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1)
        layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
            layer.w2_weight_offset.data.shape[0], -1)
        layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose(
            1, 2).contiguous()
        layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
            1, 2).contiguous()

        layer.w13_weight_scale_second.data, w13_bias = self.process_scale(
            layer.w13_weight, layer.w13_weight_scale.data,
            layer.w13_weight_scale_second.data)
        layer.register_parameter(
            "w13_scale_bias",
            torch.nn.Parameter(w13_bias, requires_grad=False))
        layer.w2_weight_scale_second.data, w2_bias = self.process_scale(
            layer.w2_weight, layer.w2_weight_scale.data,
            layer.w2_weight_scale_second.data)
        layer.register_parameter(
            "w2_scale_bias", torch.nn.Parameter(w2_bias, requires_grad=False))

        # Both npu_quantize calls use the same scale of 1.0; build it once.
        unit_scale = torch.tensor([1.]).npu()
        layer.w13_weight.data = torch_npu.npu_quantize(
            layer.w13_weight.data.to(torch.float32), unit_scale, None,
            torch.quint4x2, -1, False)
        layer.w2_weight.data = torch_npu.npu_quantize(
            layer.w2_weight.data.to(torch.float32), unit_scale, None,
            torch.quint4x2, -1, False)
is not None: + if group_list_type == 0: + group_list = torch.cat( + [group_list[:1], torch.diff(group_list, dim=0)]) + group_list_type = 1 + bias1 = [w1_scale_bias] + bias2 = [w2_scale_bias] + # TODO w4a8 scene: dynamic acquisition of dtype in the future + _output_dtype = torch.bfloat16 + # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w1], scale=[w1_scale], + bias=bias1, per_token_scale=[pertoken_scale], split_item=2, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=w2_scale.dtype)[0] + output_dtype=_output_dtype)[0] # act_fn: swiglu hidden_states = torch_npu.npu_swiglu(hidden_states) @@ -172,12 +188,13 @@ def apply_mlp(hidden_states: torch.Tensor, x=[hidden_states], weight=[w2], scale=[w2_scale], + bias=bias2, per_token_scale=[swiglu_out_scale], split_item=2, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=w2_scale.dtype)[0] + output_dtype=_output_dtype)[0] return hidden_states @@ -202,6 +219,8 @@ def fused_experts_with_mc2( mc2_mask: Optional[torch.Tensor] = None, shared_gate_up: Optional[Any] = None, shared_dequant_scale: Optional[Any] = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: assert mc2_mask is not None if log2phy is not None: @@ -270,13 +289,25 @@ def fused_experts_with_mc2( shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1] # `expand_x` will be disposed in the `apply_mlp` function - down_out_list = apply_mlp_decode(expand_x, - w1, - w1_scale, - w2, - w2_scale, - expert_token_nums, - dynamic_scale=dynamic_scale) + if w1_scale_bias is None: + down_out_list = apply_mlp_decode(expand_x, + w1, + w1_scale, + w2, + w2_scale, + expert_token_nums, + dynamic_scale=dynamic_scale) + else: + # w4a8 scene, cannot use apply_mlp_decode because the operator is not supported + down_out_list = apply_mlp(expand_x, + w1, + w1_scale, + w2, + 
w2_scale, + expert_token_nums, + dynamic_scale=dynamic_scale, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias) # moeCombine kwargs_mc2 = { @@ -372,6 +403,8 @@ def fused_experts_with_all2all( ep_group: GroupCoordinator = None, log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, ): if log2phy is not None: topk_ids = log2phy[topk_ids] @@ -457,7 +490,9 @@ def fused_experts_with_all2all( w2_scale, expert_tokens, #16 dynamic_scale=dynamic_scale, - group_list_type=group_list_type) + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias) if expert_map is not None: reordered_outputs = torch.index_select(