[3/N][refactor] refactoer quantization (#2504)
### What this PR does / why we need it? Move torchair related qunatization section into torchair dir to make the code clear. Next step we'll remove all torchair related code outside of torchair quantization. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? vLLM version: main vLLM main:ab9f2cfd19- vLLM version: v0.10.1.1 - vLLM main:959783fb99Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
176
tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py
Normal file
176
tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import copy
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
|
||||
TorchairAscendW4A8DynamicFusedMoEMethod,
|
||||
TorchairAscendW4A8DynamicLinearMethod)
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicLinearMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.method = TorchairAscendW4A8DynamicLinearMethod()
|
||||
self.method.group_size = 8
|
||||
|
||||
def test_get_weight(self):
|
||||
weight = self.method.get_weight(8, 32, torch.bfloat16)
|
||||
self.assertEqual(weight["weight"].dtype, torch.int8)
|
||||
self.assertEqual(weight["weight"].shape, (32, 8))
|
||||
|
||||
def test_get_pergroup_param(self):
|
||||
params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
|
||||
self.assertEqual(params["weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_scale"].shape, (32, 1))
|
||||
self.assertEqual(params["weight_offset"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_offset"].shape, (32, 1))
|
||||
self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
|
||||
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
experts = 8
|
||||
input_size = 16
|
||||
output_size = 56
|
||||
group_size = 2
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group')
|
||||
@patch("vllm_ascend.ascend_config.get_ascend_config")
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group'
|
||||
)
|
||||
@patch('torch.distributed.get_rank', return_value=0)
|
||||
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
|
||||
mock_get_ep_group, get_current_vllm_config):
|
||||
mock_ascend_config = Mock()
|
||||
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
|
||||
mock_get_ascend_config.return_value = mock_ascend_config
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(quant_description={
|
||||
"group_size": self.group_size,
|
||||
"version": "0.0.0"
|
||||
})
|
||||
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
|
||||
get_current_vllm_config.return_value = mock_vllm_config
|
||||
self.quant_method = TorchairAscendW4A8DynamicFusedMoEMethod()
|
||||
|
||||
def test_get_weight(self):
|
||||
# old quant version w4a8 weight
|
||||
param_dict = self.quant_method.get_weight(self.experts,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
|
||||
self.assertEqual(param_dict["w13_weight"].shape,
|
||||
(self.experts, 2 * self.input_size, self.output_size))
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
param_dict = self.quant_method.get_weight(self.experts,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
|
||||
self.assertEqual(param_dict["w13_weight"].shape,
|
||||
(self.experts, self.input_size, self.output_size))
|
||||
|
||||
def test_get_dynamic_quant_param(self):
|
||||
# old quant version weight
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].shape,
|
||||
(self.experts, 2 * self.input_size, 1))
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size))
|
||||
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w2_weight_scale"].shape,
|
||||
(self.experts, self.output_size, 1))
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size))
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
|
||||
self.assertEqual(
|
||||
param_dict["w2_scale_bias"].shape,
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
|
||||
|
||||
@patch('torch_npu.npu_quantize')
|
||||
@patch('torch.Tensor.npu')
|
||||
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
|
||||
# old quant version weight
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, 2 * self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size),
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size, 1), dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size),
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
new_layer = copy.deepcopy(layer)
|
||||
|
||||
mock_npu.return_value = torch.Tensor()
|
||||
mock_npu_quantize.return_value = torch.Tensor()
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
self.assertTrue(hasattr(layer, "w13_scale_bias"))
|
||||
self.assertEqual(layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
|
||||
self.assertTrue(hasattr(layer, "w2_scale_bias"))
|
||||
self.assertEqual(layer.w2_scale_bias.data.shape,
|
||||
(self.experts, self.output_size))
|
||||
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
new_layer.w13_weight.data = torch.zeros(
|
||||
(self.experts, self.input_size, self.output_size),
|
||||
dtype=torch.int8)
|
||||
new_layer.w2_weight.data = torch.zeros(
|
||||
(self.experts, self.output_size // 2, self.input_size),
|
||||
dtype=torch.int8)
|
||||
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
|
||||
dtype=torch.float32)
|
||||
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
|
||||
requires_grad=False)
|
||||
w2_scale_bias = torch.zeros(
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
|
||||
dtype=torch.float32)
|
||||
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
|
||||
requires_grad=False)
|
||||
self.quant_method.process_weights_after_loading(new_layer)
|
||||
self.assertEqual(new_layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
self.assertEqual(new_layer.w2_scale_bias.data.shape,
|
||||
(self.experts, self.output_size))
|
||||
75
tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py
Normal file
75
tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
||||
torchair_fused_experts_with_all2all
|
||||
|
||||
|
||||
class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.hidden_size = 128
|
||||
self.num_tokens = 128
|
||||
self.placeholder = torch.randn(self.num_tokens,
|
||||
self.hidden_size,
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
@patch("torch.distributed.all_to_all_single")
|
||||
@patch("torch_npu.npu_moe_re_routing")
|
||||
@patch("torch_npu.npu_grouped_matmul")
|
||||
@patch("torch_npu.npu_swiglu")
|
||||
@patch("torch_npu.npu_dynamic_quant")
|
||||
@patch("torch_npu.npu_moe_finalize_routing")
|
||||
@patch("torch_npu.npu_moe_init_routing")
|
||||
def test_torchair_fused_experts_with_all2all(
|
||||
self, mock_moe_init_routing, mock_moe_finalize_routing,
|
||||
mock_dynamic_quant, mock_swiglu, mock_grouped_matmul,
|
||||
mock_moe_re_routing, mock_all_to_all_single):
|
||||
|
||||
expert_map = MagicMock()
|
||||
ep_group = MagicMock()
|
||||
placeholder_int8 = torch.randint(0,
|
||||
100,
|
||||
(self.num_tokens, self.hidden_size),
|
||||
dtype=torch.int8)
|
||||
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
|
||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||
input)
|
||||
mock_moe_init_routing.return_value = (
|
||||
placeholder_int8,
|
||||
placeholder_ones,
|
||||
placeholder_ones,
|
||||
)
|
||||
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
|
||||
torch.randint(0,
|
||||
100,
|
||||
(self.num_tokens, ),
|
||||
dtype=torch.int32),
|
||||
self.placeholder)
|
||||
mock_grouped_matmul.return_value = self.placeholder
|
||||
mock_swiglu.return_value = self.placeholder
|
||||
mock_dynamic_quant.return_value = (
|
||||
placeholder_int8,
|
||||
torch.randn(self.num_tokens),
|
||||
)
|
||||
mock_moe_finalize_routing.return_value = self.placeholder
|
||||
|
||||
result = torchair_fused_experts_with_all2all(
|
||||
hidden_states=self.placeholder,
|
||||
w1=self.placeholder,
|
||||
w1_scale=self.placeholder,
|
||||
w2=self.placeholder,
|
||||
w2_scale=self.placeholder,
|
||||
topk_weights=self.placeholder,
|
||||
topk_ids=self.placeholder,
|
||||
top_k=8,
|
||||
expert_map=expert_map,
|
||||
ep_group=ep_group,
|
||||
log2phy=None,
|
||||
global_redundant_expert_num=256,
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
self.assertEqual(result.shape, (128, 128))
|
||||
@@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE
|
||||
from vllm_ascend.torchair import utils
|
||||
|
||||
|
||||
@@ -120,3 +121,15 @@ class TestTorchairUtils(TestBase):
|
||||
|
||||
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
|
||||
mock_npu_cast.assert_not_called()
|
||||
|
||||
def test_torchair_quant_method_register(self):
|
||||
|
||||
TorchairW8A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
|
||||
"W8A8_DYNAMIC"]
|
||||
TorchairW4A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
|
||||
"W4A8_DYNAMIC"]
|
||||
utils.torchair_quant_method_register()
|
||||
self.assertNotEqual(TorchairW8A8DYNAMICQuantizer,
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"])
|
||||
self.assertNotEqual(TorchairW4A8DYNAMICQuantizer,
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"])
|
||||
|
||||
@@ -71,8 +71,9 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
||||
TorchairAscendW8A8DynamicLinearMethod
|
||||
from vllm_ascend.utils import dispose_tensor, npu_prefetch
|
||||
|
||||
|
||||
@@ -261,8 +262,9 @@ class TorchairDeepseekV2MLP(nn.Module):
|
||||
quant_method = self.gate_up_proj.quant_method
|
||||
if isinstance(quant_method, UnquantizedLinearMethod):
|
||||
self.act_fn = TorchairDeepseekV2SiluAndMul()
|
||||
elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
|
||||
quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
|
||||
elif (isinstance(quant_method, AscendLinearMethod)
|
||||
and isinstance(quant_method.quant_method,
|
||||
TorchairAscendW8A8DynamicLinearMethod)):
|
||||
# TODO(sdmyzlp): Currently preserved as before:
|
||||
# 1. The only quantization supported for silu is W8A8Dynamic
|
||||
# 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
|
||||
|
||||
0
vllm_ascend/torchair/quantization/__init__.py
Normal file
0
vllm_ascend/torchair/quantization/__init__.py
Normal file
29
vllm_ascend/torchair/quantization/torchair_quantizer.py
Normal file
29
vllm_ascend/torchair/quantization/torchair_quantizer.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from vllm_ascend.quantization.quantizer import VLLMAscendQuantizer
|
||||
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
|
||||
TorchairAscendW4A8DynamicFusedMoEMethod,
|
||||
TorchairAscendW4A8DynamicLinearMethod)
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
|
||||
TorchairAscendW8A8DynamicFusedMoEMethod,
|
||||
TorchairAscendW8A8DynamicLinearMethod)
|
||||
|
||||
|
||||
class TorchairW8A8DYNAMICQuantizer(VLLMAscendQuantizer):
|
||||
|
||||
@staticmethod
|
||||
def build_linear_method():
|
||||
return TorchairAscendW8A8DynamicLinearMethod()
|
||||
|
||||
@staticmethod
|
||||
def build_moe_method():
|
||||
return TorchairAscendW8A8DynamicFusedMoEMethod()
|
||||
|
||||
|
||||
class TorchairW4A8DYNAMICQuantizer(VLLMAscendQuantizer):
|
||||
|
||||
@staticmethod
|
||||
def build_linear_method():
|
||||
return TorchairAscendW4A8DynamicLinearMethod()
|
||||
|
||||
@staticmethod
|
||||
def build_moe_method():
|
||||
return TorchairAscendW4A8DynamicFusedMoEMethod()
|
||||
424
vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py
Normal file
424
vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py
Normal file
@@ -0,0 +1,424 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
|
||||
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
|
||||
|
||||
class TorchairAscendW4A8DynamicLinearMethod:
|
||||
"""Linear method for Ascend W4A8_DYNAMIC
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
try:
|
||||
self.group_size = get_current_vllm_config(
|
||||
).quant_config.quant_description.get("group_size", 256)
|
||||
except AttributeError:
|
||||
self.group_size = 256
|
||||
|
||||
@staticmethod
|
||||
def get_weight(input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {
|
||||
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def get_pergroup_param(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_scale_second"] = torch.empty(output_size,
|
||||
input_size //
|
||||
self.group_size,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset_second"] = torch.empty(output_size,
|
||||
input_size //
|
||||
self.group_size,
|
||||
dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
|
||||
per_group_scale: torch.Tensor):
|
||||
k, n = weight.shape
|
||||
group_num, n = per_group_scale.shape
|
||||
weight_high = weight.to(torch.float32).reshape(
|
||||
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
|
||||
weight_high = weight_high.reshape(k, n)
|
||||
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
|
||||
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
|
||||
return antiquant_scale.npu(), bias
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch_npu.npu_weight_quant_batchmatmul(
|
||||
x,
|
||||
layer.weight,
|
||||
antiquant_scale=layer.weight_scale_second.to(x.dtype),
|
||||
antiquant_group_size=self.group_size,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
|
||||
torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
|
||||
layer.weight.data,
|
||||
layer.weight_scale.data,
|
||||
layer.weight_scale_second.data.transpose(0, 1).contiguous(),
|
||||
)
|
||||
param = torch.nn.Parameter(scale_bias, requires_grad=False)
|
||||
layer.register_parameter("weight_scale_bias", param)
|
||||
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
|
||||
layer.weight.data.to(torch.int32))
|
||||
|
||||
|
||||
class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W4A8_DYNAMIC.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
# NOTE: new quantize weights: 2 int4 pack into int8
|
||||
self.new_quant_version = quant_version == "1.0.0"
|
||||
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
|
||||
if self.new_quant_version and self.tp_size > 16:
|
||||
raise ValueError(
|
||||
"The current weight does not support moe part tp>16.")
|
||||
|
||||
try:
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
||||
local_rank)
|
||||
except AttributeError:
|
||||
self.moe_all_to_all_group_name = ""
|
||||
|
||||
def get_weight(self, num_experts: int,
|
||||
intermediate_size_per_partition: int, hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
if self.new_quant_version:
|
||||
w13_output_size = intermediate_size_per_partition
|
||||
w2_output_size = hidden_sizes // 2
|
||||
else:
|
||||
w13_output_size = 2 * intermediate_size_per_partition
|
||||
w2_output_size = hidden_sizes
|
||||
|
||||
param_dict["w13_weight"] = torch.empty(num_experts,
|
||||
w13_output_size,
|
||||
hidden_sizes,
|
||||
dtype=torch.int8)
|
||||
param_dict["w2_weight"] = torch.empty(num_experts,
|
||||
w2_output_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8)
|
||||
return param_dict
|
||||
|
||||
def get_dynamic_quant_param(self, num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
if self.new_quant_version:
|
||||
param_dict["w13_scale_bias"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_scale_bias"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
16 // self.tp_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
shared_gate_up, shared_dequant_scale = None, None
|
||||
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(quantized_x_for_share, router_logits)
|
||||
share_up_out, _ = shared_experts.gate_up_proj(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||
0], share_up_out[1]
|
||||
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return torchair_fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale_second,
|
||||
w2_scale=layer.w2_weight_scale_second,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
is_torchair=self.torchair_graph_enabled,
|
||||
quantized_x_for_share=shared_gate_up,
|
||||
dynamic_scale_for_share=shared_dequant_scale,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
else:
|
||||
# The current implementation of deepseek moe splits hidden_states
|
||||
# according to tp_size before they are feed into layers module.
|
||||
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
||||
# dispatch/combine tokens.
|
||||
return torchair_fused_experts_with_all2all(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale_second,
|
||||
w2_scale=layer.w2_weight_scale_second,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
)
|
||||
|
||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||
group_num, k, n = weight.shape
|
||||
# the weight of the new version is reduced by half by pack n, so it needs to be restored
|
||||
if self.new_quant_version:
|
||||
n = n * 2
|
||||
per_group_scale = per_group_scale.reshape(group_num, -1, n)
|
||||
group_num, quantgroup_num, n = per_group_scale.shape
|
||||
bias = None
|
||||
if not self.new_quant_version:
|
||||
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
|
||||
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
|
||||
weight_high = weight_high.reshape([group_num, k, n])
|
||||
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
|
||||
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
|
||||
torch.float32)
|
||||
scale_fp32_np = scale_fp32.cpu().numpy()
|
||||
scale_fp32_np.dtype = np.uint32
|
||||
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
|
||||
dtype=np.uint32)
|
||||
|
||||
sscale_uint64[..., ::2] = scale_fp32_np
|
||||
|
||||
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
|
||||
dtype=np.int64).copy()
|
||||
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
|
||||
group_num, quantgroup_num, n)
|
||||
sscale_uint64_tensor = sscale_uint64_tensor.npu()
|
||||
return sscale_uint64_tensor, bias
|
||||
|
||||
def update_bias(self, layer, w13_bias, w2_bias):
|
||||
if self.new_quant_version:
|
||||
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
|
||||
1, 2).contiguous().sum(axis=1)
|
||||
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
|
||||
1, 2).contiguous().sum(axis=1)
|
||||
else:
|
||||
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
|
||||
layer.register_parameter("w13_scale_bias", w13_scale_bias)
|
||||
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
|
||||
layer.register_parameter("w2_scale_bias", w2_scale_bias)
|
||||
|
||||
def pack_to_int32(self, weight: torch.Tensor):
|
||||
if self.new_quant_version:
|
||||
group_num, k, n = weight.shape
|
||||
assert n % 4 == 0, "the last dim of weight needs to be divided by 4"
|
||||
packed_n = n // 4
|
||||
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
|
||||
packed_weight = torch.from_numpy(
|
||||
np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32))
|
||||
return packed_weight.reshape(group_num, k, packed_n).npu()
|
||||
else:
|
||||
return torch_npu.npu_quantize(weight.to(torch.float32),
|
||||
torch.tensor([1.]).npu(), None,
|
||||
torch.quint4x2, -1, False)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
|
||||
1, 2).contiguous()
|
||||
|
||||
layer.w13_weight_scale_second.data, w13_bias = self.process_scale(
|
||||
layer.w13_weight, layer.w13_weight_scale.data,
|
||||
layer.w13_weight_scale_second.data)
|
||||
layer.w2_weight_scale_second.data, w2_bias = self.process_scale(
|
||||
layer.w2_weight, layer.w2_weight_scale.data,
|
||||
layer.w2_weight_scale_second.data)
|
||||
|
||||
self.update_bias(layer, w13_bias, w2_bias)
|
||||
|
||||
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
|
||||
1016
vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
Normal file
1016
vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -38,6 +38,7 @@ from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||
check_torchair_cache_exist,
|
||||
converting_weight_acl_format,
|
||||
register_torchair_model,
|
||||
torchair_quant_method_register,
|
||||
write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_310p)
|
||||
@@ -67,6 +68,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
self._check_batch_sizes_consistency()
|
||||
register_torchair_model()
|
||||
torchair_quant_method_register()
|
||||
|
||||
def _get_forward_metadata_across_dp_and_pad(
|
||||
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
||||
|
||||
@@ -170,3 +170,15 @@ def register_torchair_model():
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3ForCausalLM",
|
||||
"vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
||||
|
||||
|
||||
def torchair_quant_method_register():
|
||||
from vllm_ascend.quantization.quantizer import \
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE
|
||||
from vllm_ascend.torchair.quantization.torchair_quantizer import (
|
||||
TorchairW4A8DYNAMICQuantizer, TorchairW8A8DYNAMICQuantizer)
|
||||
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE[
|
||||
"W8A8_DYNAMIC"] = TorchairW8A8DYNAMICQuantizer
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE[
|
||||
"W4A8_DYNAMIC"] = TorchairW4A8DYNAMICQuantizer
|
||||
|
||||
Reference in New Issue
Block a user