### What this PR does / why we need it? **Problem Description:** The existing implementation for the w4a8-dynamic linear method only supports the old quantization format from msmodelslim. When attempting to load models quantized with the new version, vLLM encounters errors due to mismatched tensor shapes and unprocessed quantization parameters. Relevant issues: - https://github.com/vllm-project/vllm-ascend/issues/3192 - https://github.com/vllm-project/vllm-ascend/issues/3152 **Proposed Changes:** 1. Add support for w4a8 dynamic (new format) in AscendW4A8DynamicLinearMethod and TorchairAscendW4A8DynamicLinearMethod 2. Add unit tests and e2e tests for w4a8 dynamic new and old format models <details> <summary><b>details</b></summary> 1. **Support for new w4a8-dynamic format:** * Detects quantization format by reading the "version" field in quant_description to ensure backward compatibility. * Handles the new pre-packed weight format (`2x int4` in an `int8`), which has a halved dimension. It tells the vLLM loader how to unpack it using `_packed_dim` and `_packed_factor`. * Supports the new `scale_bias` parameter, setting its shape based on the layer type, as required by msmodelslim. For API consistency and future use, the `layer_type` parameter was also added to other quantization methods. * Updates the weight processing logic: new format weights are handled with `.view(torch.int32)` since they're pre-packed, while old ones are processed with `npu_convert_weight_to_int4pack`. 2. **New unit and E2E tests:** * Added unit tests that verify the logic for both the old and new formats. * Split the distributed E2E test to confirm that both old and new format models work correctly. </details> Theoretically, these changes will provide support for all common new-version w4a8 (dynamic) models from msmodelslim. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
I implemented relevant unit tests and e2e tests and tested the changes with the following commands: ```bash # unit tests python -m pytest tests/ut/quantization/test_w4a8_dynamic.py tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py -v # e2e tests pytest tests/e2e/singlecard/test_quantization.py -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version -v -s pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC -v -s ``` I also tested Hunyuan-1.8B-Instruct quantized with the new w4a8-dynamic format: ``` vllm serve ./models/Hunyuan-1.8B-Instruct-quantized --gpu-memory-utilization 0.96 --quantization ascend --max-model-len 9600 --seed 0 --max-num-batched-tokens 16384 ``` All tests mentioned passed locally. **NOTE: I use a quantized model from my own repo in test_offline_inference_distributed.py**. Here is the description: [Anionex/Qwen3-1.7B-W4A8-V1](https://modelscope.cn/models/Anionex/Qwen3-1.7B-W4A8-V1/summary) (including quantization steps). This should be replaced by a model in the vllm-ascend CI ModelScope repo. Thanks for reading! - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Anionex <1005128408@qq.com>
491 lines
22 KiB
Python
491 lines
22 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch_npu
|
|
from vllm.config import get_current_vllm_config
|
|
from vllm.distributed import get_ep_group
|
|
from vllm.forward_context import get_forward_context
|
|
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
|
from vllm_ascend.ops.moe.experts_selector import select_experts
|
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
|
|
|
|
|
class AscendW4A8DynamicLinearMethod:
    """Linear method for Ascend W4A8_DYNAMIC.

    Two msmodelslim checkpoint layouts are handled, selected by the
    ``version`` entry of the quant description:

    * old layout (``version`` absent or anything other than ``"1.0.0"``):
      one int4 value is stored per int8 element, and the weight is packed
      with ``npu_convert_weight_to_int4pack`` after loading;
    * new layout (``version == "1.0.0"``): two int4 values are already packed
      into each int8 element, so the on-disk output dimension is halved.
    """

    def __init__(self):
        self.transpose_weight = True

        vllm_config = get_current_vllm_config()
        quant_description = vllm_config.quant_config.quant_description
        self.group_size = quant_description.get("group_size", 256)
        quant_version = quant_description.get("version", "0")
        # Only the exact string "1.0.0" selects the pre-packed layout.
        self.new_quant_version = quant_version == "1.0.0"

        from vllm.distributed import get_tensor_model_parallel_world_size
        self.tp_size = get_tensor_model_parallel_world_size()

    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Create weight parameters.

        Args:
            input_size: logical input dimension of the layer.
            output_size: logical output dimension of the layer.
            params_dtype: unused; weights are always created as int8.

        Returns:
            A dict with an int8 ``"weight"`` tensor of shape
            ``[output_size, input_size]``. For the new layout (two int4
            packed into one int8) the output dimension is halved, e.g.
            ``[2048, 3072] -> [1024, 3072]``. The dict then also carries
            ``"_packed_dim"`` and ``"_packed_factor"`` for vLLM's weight
            loader.
        """
        params_dict: Dict[str, Any] = {}

        if self.new_quant_version:
            # Two int4 values per int8: the output dimension is compressed.
            pack_factor = 2
            actual_output_size = output_size // pack_factor
            params_dict["weight"] = torch.empty(actual_output_size,
                                                input_size,
                                                dtype=torch.int8)
            # Packing information for vLLM's weight_loader.
            params_dict["_packed_dim"] = 0
            params_dict["_packed_factor"] = pack_factor
        else:
            params_dict["weight"] = torch.empty(output_size,
                                                input_size,
                                                dtype=torch.int8)

        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """No per-tensor parameters are used by this method."""
        return {}

    @staticmethod
    def get_perchannel_param(output_size: int,
                             params_dtype: torch.dtype) -> Dict[str, Any]:
        """No per-channel parameters are created here (see get_pergroup_param)."""
        return {}

    def get_pergroup_param(self,
                           input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype,
                           layer_type: Optional[str] = None) -> Dict[str, Any]:
        """Create per-group quantization parameters.

        Args:
            input_size: logical input dimension; must be a multiple of
                ``self.group_size`` for the second-level tensors.
            output_size: logical output dimension.
            params_dtype: dtype for the scale/offset tensors.
            layer_type: ``"row"`` for row-parallel layers (e.g. down_proj,
                o_proj). Only affects ``scale_bias`` in the new layout.

        Returns:
            ``weight_scale``/``weight_offset`` of shape ``[output_size, 1]``
            and ``weight_scale_second``/``weight_offset_second`` of shape
            ``[output_size, input_size // group_size]``. The new layout adds
            a float32 ``scale_bias`` of shape ``[output_size, 16]`` for
            ``layer_type == "row"``, otherwise ``[output_size, 1]``.
        """
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        params_dict["weight_scale_second"] = torch.empty(output_size,
                                                         input_size //
                                                         self.group_size,
                                                         dtype=params_dtype)
        params_dict["weight_offset_second"] = torch.empty(output_size,
                                                          input_size //
                                                          self.group_size,
                                                          dtype=params_dtype)

        # NOTE: In w4a8 quantization implementation,
        # for down_proj and o_proj (layer_type == "row") scale_bias shape is
        # [output_size, 16], others are [output_size, 1].
        if self.new_quant_version:
            scale_bias_dim = 16 if layer_type == "row" else 1

            params_dict["scale_bias"] = torch.empty(output_size,
                                                    scale_bias_dim,
                                                    dtype=torch.float32)
        return params_dict

    @staticmethod
    def process_scale_second(weight: torch.Tensor,
                             scale: torch.Tensor,
                             per_group_scale: torch.Tensor,
                             is_new_quant: bool = False):
        """Fold the first- and second-level scales into one antiquant scale.

        Args:
            weight: transposed weight ``[k, n]``. In the new layout ``n`` is
                the packed (halved) size.
            scale: first-level scale, flattened to ``[n_logical]``.
            per_group_scale: second-level per-group scale
                ``[group_num, n_logical]``.
            is_new_quant: whether ``weight`` uses the new pre-packed layout.

        Returns:
            ``(antiquant_scale, bias)``.
            ``antiquant_scale`` has shape ``[group_num, n_logical]`` and has
            been moved with ``.npu()``. ``bias`` is the old-layout term
            ``8 * sum_k(weight * per_group_scale * scale)``; it is ``None``
            for the new layout.
        """
        k, n = weight.shape
        # Unpack (rather than index) so a non-2-D scale still fails loudly.
        group_num, _ = per_group_scale.shape

        if is_new_quant:
            # Restore the logical dimension of the packed weight.
            n = n * 2

        bias = None
        if not is_new_quant:
            weight_high = weight.to(torch.float32).reshape(
                group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
            weight_high = weight_high.reshape(k, n)
            bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
        # NOTE: scale_bias is not used currently
        # because in msmodelslim w4a8 uses symmetric quantization.
        # TODO: support potential future asymmetric quantization
        antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
        return antiquant_scale.npu(), bias

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = None,
    ) -> torch.Tensor:
        """Run the int4 weight-quantized matmul and add ``bias`` if given.

        Args:
            layer: layer holding the processed ``weight`` and
                ``weight_scale_second`` parameters.
            x: activation tensor.
            bias: optional bias added to the matmul output. It was previously
                ignored, which silently dropped the bias of biased layers.
                Callers are presumably expected to pass ``None`` wherever the
                bias must not be added (e.g. on non-zero ranks of
                row-parallel layers) — confirm against the caller.
            tp_rank: unused by this method.
        """
        output = torch_npu.npu_weight_quant_batchmatmul(
            x,
            layer.weight,
            antiquant_scale=layer.weight_scale_second.to(x.dtype),
            antiquant_group_size=self.group_size,
        )
        if bias is not None:
            output = output + bias.to(output.dtype)
        return output

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Convert loaded checkpoint tensors into the layout used by ``apply``.

        Steps:
        * transpose the weight;
        * flatten the first-level scale and offset;
        * fold the scales via ``process_scale_second``;
        * normalise ``scale_bias`` (new layout) or register
          ``weight_scale_bias`` (old layout);
        * pack the int4 weight into int32.
        """
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight_scale.data = layer.weight_scale.data.flatten().to(
            torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()
        layer.weight_scale_second.data, scale_bias = self.process_scale_second(
            layer.weight.data,
            layer.weight_scale.data,
            layer.weight_scale_second.data.transpose(0, 1).contiguous(),
            is_new_quant=self.new_quant_version,
        )

        if self.new_quant_version:
            # Single-column scale_bias is flattened; the 16-column ("row")
            # form is kept 2-D.
            if hasattr(layer, "scale_bias"):
                if layer.scale_bias.data.shape[1] == 1:
                    layer.scale_bias.data = layer.scale_bias.data.flatten()
                else:
                    layer.scale_bias.data = layer.scale_bias.data.contiguous()
        else:
            if scale_bias is not None:
                param = torch.nn.Parameter(scale_bias, requires_grad=False)
                layer.register_parameter("weight_scale_bias", param)

        # Convert to NPU-specific int4pack format
        if self.new_quant_version:
            # Weights on disk are already packed int4: reinterpret every
            # 4 int8 (each holding 2 int4) as one int32.
            assert layer.weight.data.shape[-1] % 4 == 0, \
                f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
            layer.weight.data = layer.weight.data.view(
                torch.int32).contiguous()
        else:
            # Weights are not compressed and must be packed on device.
            layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
                layer.weight.data.to(torch.int32))
|
|
|
|
|
|
class AscendW4A8DynamicFusedMoEMethod:
    """FusedMoe method for Ascend W4A8_DYNAMIC.

    Supports the old msmodelslim layout (one int4 per int8) and the new
    layout (``version == "1.0.0"``, two int4 pre-packed per int8). The new
    layout halves the packed weight dimension and ships explicit
    ``*_scale_bias`` tensors.
    """

    def __init__(self):
        self.transpose_weight = True

        self.ep_group = get_ep_group()

        vllm_config = get_current_vllm_config()
        quant_description = vllm_config.quant_config.quant_description
        self.group_size = quant_description.get("group_size", 256)
        # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
        self.is_per_channel_weight = self.group_size == 0
        quant_version = quant_description.get("version", "0")
        # NOTE: new quantize weights: 2 int4 pack into int8
        self.new_quant_version = quant_version == "1.0.0"
        self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
        ascend_config = get_ascend_config()
        self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
        if self.new_quant_version and self.tp_size > 16:
            raise ValueError(
                "The current weight does not support moe part tp>16.")

        try:
            device_group = get_mc2_group().device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            local_rank = torch.distributed.get_rank(group=device_group)
            backend = device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
                local_rank)
        except AttributeError:
            self.moe_all_to_all_group_name = ""

    def get_weight(self, num_experts: int,
                   intermediate_size_per_partition: int, hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Create int8 expert weights ``w13_weight`` and ``w2_weight``.

        In the new layout the packed output dimension is halved:
        ``2 * intermediate -> intermediate`` for w13 and
        ``hidden -> hidden // 2`` for w2. ``params_dtype`` is unused.
        """
        param_dict = {}
        if self.new_quant_version:
            w13_output_size = intermediate_size_per_partition
            w2_output_size = hidden_sizes // 2
        else:
            w13_output_size = 2 * intermediate_size_per_partition
            w2_output_size = hidden_sizes

        param_dict["w13_weight"] = torch.empty(num_experts,
                                               w13_output_size,
                                               hidden_sizes,
                                               dtype=torch.int8)
        param_dict["w2_weight"] = torch.empty(num_experts,
                                              w2_output_size,
                                              intermediate_size_per_partition,
                                              dtype=torch.int8)
        return param_dict

    def get_dynamic_quant_param(self, num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        """Create the float32 scale/offset tensors for both expert weights.

        Always created:
        * first-level ``*_weight_scale``/``*_weight_offset`` (last dim 1).

        Created unless weights are per-channel (``group_size == 0``):
        * second-level ``*_second`` tensors with one column per quant group.

        New layout only:
        * ``w13_scale_bias`` of shape ``[E, 2 * intermediate, 1]``;
        * ``w2_scale_bias`` of shape ``[E, hidden, 16 // tp_size]``.
        """
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float32)

        param_dict["w13_weight_offset"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float32)

        param_dict["w2_weight_scale"] = torch.empty(num_experts,
                                                    hidden_sizes,
                                                    1,
                                                    dtype=torch.float32)
        param_dict["w2_weight_offset"] = torch.empty(num_experts,
                                                     hidden_sizes,
                                                     1,
                                                     dtype=torch.float32)
        if not self.is_per_channel_weight:
            param_dict["w13_weight_scale_second"] = torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_sizes // self.group_size,
                dtype=torch.float32)
            param_dict["w13_weight_offset_second"] = torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_sizes // self.group_size,
                dtype=torch.float32)

            param_dict["w2_weight_scale_second"] = torch.empty(
                num_experts,
                hidden_sizes,
                intermediate_size_per_partition // self.group_size,
                dtype=torch.float32)
            param_dict["w2_weight_offset_second"] = torch.empty(
                num_experts,
                hidden_sizes,
                intermediate_size_per_partition // self.group_size,
                dtype=torch.float32)

        if self.new_quant_version:
            param_dict["w13_scale_bias"] = torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                1,
                dtype=torch.float32)
            # NOTE(review): presumably the 16 checkpoint columns are split
            # across TP ranks; assumes tp_size divides 16 — confirm.
            param_dict["w2_scale_bias"] = torch.empty(num_experts,
                                                      hidden_sizes,
                                                      16 // self.tp_size,
                                                      dtype=torch.float32)

        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Route tokens to experts and run the int4 W4A8 fused experts.

        Expert selection uses ``select_experts``. The expert computation is
        delegated to the forward context's ``moe_comm_method.fused_experts``.
        When ``enable_force_load_balance`` is set, ``topk_ids`` are replaced
        by uniformly random expert ids.
        """
        assert router_logits.shape[
            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        # this is a naive implementation for experts load balance so as
        # to avoid accumulating too much tokens on a single rank.
        # currently it is only activated when doing profile runs.
        if enable_force_load_balance:
            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

        topk_weights = topk_weights.to(x.dtype)

        moe_comm_method = get_forward_context().moe_comm_method
        return moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w1_scale_bias=layer.w13_scale_bias,
            w2_scale_bias=layer.w2_scale_bias,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_int4_w4a8=True,
            expert_map=expert_map,
            log2phy=log2phy,
            global_redundant_expert_num=global_redundant_expert_num,
            shared_experts=shared_experts,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
            dynamic_eplb=self.dynamic_eplb)

    @staticmethod
    def _float32_bits_as_int64(t: torch.Tensor) -> torch.Tensor:
        """Reinterpret float32 bit patterns as uint32 values widened to int64.

        The result keeps ``t``'s shape and lives on the CPU. Each element's
        low 32 bits hold the raw float32 bits and the high 32 bits are zero.
        This matches the former interleaved-uint32 buffer construction on
        little-endian hosts.

        Raises:
            TypeError: if ``t`` is not float32. Reinterpreting other dtypes
                would silently produce wrong bit patterns.
        """
        if t.dtype != torch.float32:
            raise TypeError(f"expected a float32 tensor, got {t.dtype}")
        # ndarray.view reinterprets memory; assigning to ndarray.dtype is
        # discouraged by NumPy.
        bits = t.detach().cpu().contiguous().numpy().view(np.uint32)
        return torch.from_numpy(bits.astype(np.int64))

    def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
        """Build the int64-encoded dequant scale (and old-layout bias).

        Args:
            weight: transposed expert weight ``[E, k, n]``. In the new layout
                ``n`` is the packed (halved) size.
            scale: first-level scale ``[E, n_logical, 1]``.
            per_group_scale: second-level scale ``[E, n_logical, groups]``,
                or ``None`` for per-channel weights.

        Returns:
            ``(scale_int64, bias)``. ``scale_int64`` carries the float32 scale
            bits in the low 32 bits of each int64 and has been moved with
            ``.npu()``. ``bias`` is ``8 * sum_k(weight * scales)`` for the old
            grouped layout, otherwise ``None``.
        """
        scale = scale.transpose(1, 2).contiguous()
        if self.is_per_channel_weight:
            return self._float32_bits_as_int64(scale).npu(), None
        per_group_scale = per_group_scale.transpose(1, 2).contiguous()
        group_num, k, n = weight.shape
        # the weight of the new version is reduced by half by pack n, so it needs to be restored
        if self.new_quant_version:
            n = n * 2
        per_group_scale = per_group_scale.reshape(group_num, -1, n)
        group_num, quantgroup_num, n = per_group_scale.shape
        bias = None
        if not self.new_quant_version:
            weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
                per_group_scale.reshape([group_num, quantgroup_num, 1, n])
            weight_high = weight_high.reshape([group_num, k, n])
            bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
        # Round through fp16 before encoding, as the original implementation does.
        scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
            torch.float32)
        sscale_int64 = self._float32_bits_as_int64(scale_fp32).reshape(
            group_num, quantgroup_num, n)
        return sscale_int64.npu(), bias

    def update_bias(self, layer, w13_bias, w2_bias):
        """Finalize ``w13_scale_bias``/``w2_scale_bias`` on ``layer``.

        New layout: the loaded scale_bias tensors are transposed and summed
        over their last (checkpoint) axis. Old layout: the biases computed by
        ``process_scale`` are registered as non-trainable parameters.
        """
        if self.new_quant_version:
            layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
                1, 2).contiguous().sum(axis=1)
            layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
                1, 2).contiguous().sum(axis=1)
        else:
            w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
            layer.register_parameter("w13_scale_bias", w13_scale_bias)
            w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
            layer.register_parameter("w2_scale_bias", w2_scale_bias)

    def pack_to_int32(self, weight: torch.Tensor):
        """Pack int4 weights into int32 storage.

        New layout: reinterpret every 4 int8 (each holding 2 int4) as one
        int32. Old layout: quantize to ``torch.quint4x2`` on the device.
        """
        if self.new_quant_version:
            # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
            assert weight.shape[
                -1] % 4 == 0, "the last dim of weight needs to be divided by 4"
            return weight.view(torch.int32).contiguous()
        else:
            return torch_npu.npu_quantize(weight.to(torch.float32),
                                          torch.tensor([1.]).npu(), None,
                                          torch.quint4x2, -1, False)

    def process_weights_after_loading(self, layer):
        """Convert loaded expert tensors into the layout used by ``apply``.

        Steps:
        * transpose the weights;
        * encode the scales;
        * free the second-level tensors;
        * finalize the scale biases;
        * optionally cast to FRACTAL_NZ;
        * pack to int32.
        """
        if self.transpose_weight:
            layer.w13_weight.data = layer.w13_weight.data.transpose(
                1, 2).contiguous()
            layer.w2_weight.data = layer.w2_weight.data.transpose(
                1, 2).contiguous()

        w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
            layer, "w13_weight_scale_second") else None
        w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
            layer, "w2_weight_scale_second") else None
        layer.w13_weight_scale.data, w13_bias = self.process_scale(
            layer.w13_weight, layer.w13_weight_scale.data,
            w13_weight_scale_second)
        layer.w2_weight_scale.data, w2_bias = self.process_scale(
            layer.w2_weight, layer.w2_weight_scale.data,
            w2_weight_scale_second)
        if hasattr(layer, "w13_weight_scale_second"):
            # scale_second is no longer used, release this part of the memory
            del layer.w13_weight_scale_second
            del layer.w2_weight_scale_second
            del layer.w13_weight_offset_second
            del layer.w2_weight_offset_second

        self.update_bias(layer, w13_bias, w2_bias)

        if is_enable_nz():
            layer.w13_weight.data = torch_npu.npu_format_cast(
                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
            layer.w2_weight.data = torch_npu.npu_format_cast(
                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
        layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
        layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
|