### What this PR does / why we need it?
While using the LLM Compressor quantization tool from the VLLM community
to generate quantized weights, the VLLM Ascend engine needs to be
adapted to support the compressed tensors quantization format.
1. Support Moe model W4A8 dynamic weight.
- vLLM version: v0.13.0
- vLLM main:
bde38c11df
---------
Signed-off-by: LHXuuu <scut_xlh@163.com>
Signed-off-by: menogrey <1299267905@qq.com>
Co-authored-by: menogrey <1299267905@qq.com>
614 lines
28 KiB
Python
614 lines
28 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch_npu
|
|
from vllm.config import get_current_vllm_config
|
|
from vllm.distributed import get_ep_group
|
|
from vllm.forward_context import get_forward_context
|
|
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
|
from vllm_ascend.utils import maybe_trans_nz, COMPRESSED_TENSORS_METHOD
|
|
|
|
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
|
from .registry import register_scheme
|
|
|
|
|
|
@register_scheme("W4A8_DYNAMIC", "linear")
|
|
class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
|
|
"""Linear method for Ascend W4A8_DYNAMIC."""
|
|
|
|
def __init__(self):
|
|
vllm_config = get_current_vllm_config()
|
|
self.group_size = vllm_config.quant_config.quant_description.get(
|
|
"group_size", 256)
|
|
quant_version = vllm_config.quant_config.quant_description.get(
|
|
"version", "0")
|
|
self.new_quant_version = quant_version == "1.0.0"
|
|
|
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
self.tp_size = get_tensor_model_parallel_world_size()
|
|
|
|
def get_weight(self, input_size: int, output_size: int,
|
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
"""Create weight parameters.
|
|
|
|
For new quantization version (double int4 pack into int8), the output dimension
|
|
is compressed by factor 2 (e.g., [2048, 3072] -> [1024, 3072]). The returned
|
|
dict includes "_packed_dim" and "_packed_factor" for vLLM's weight loader.
|
|
"""
|
|
params_dict = {}
|
|
|
|
if self.new_quant_version:
|
|
# double int4 pack into int8: output dimension is compressed
|
|
pack_factor = 2
|
|
actual_output_size = output_size // pack_factor
|
|
params_dict["weight"] = torch.empty(actual_output_size,
|
|
input_size,
|
|
dtype=torch.int8)
|
|
# Add packing information for vLLM's weight_loader
|
|
params_dict["_packed_dim"] = 0
|
|
params_dict["_packed_factor"] = pack_factor
|
|
else:
|
|
params_dict["weight"] = torch.empty(output_size,
|
|
input_size,
|
|
dtype=torch.int8)
|
|
|
|
return params_dict
|
|
|
|
def get_pergroup_param(self,
|
|
input_size: int,
|
|
output_size: int,
|
|
params_dtype: torch.dtype,
|
|
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
|
"""Create per-group quantization parameters."""
|
|
params_dict = {}
|
|
params_dict["weight_scale"] = torch.empty(output_size,
|
|
1,
|
|
dtype=params_dtype)
|
|
params_dict["weight_offset"] = torch.empty(output_size,
|
|
1,
|
|
dtype=params_dtype)
|
|
params_dict["weight_scale_second"] = torch.empty(output_size,
|
|
input_size //
|
|
self.group_size,
|
|
dtype=params_dtype)
|
|
params_dict["weight_offset_second"] = torch.empty(output_size,
|
|
input_size //
|
|
self.group_size,
|
|
dtype=params_dtype)
|
|
|
|
# NOTE: In w4a8 quantization implementation,
|
|
# for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16],
|
|
# others are [output_size, 1]
|
|
if self.new_quant_version:
|
|
scale_bias_dim = 16 if layer_type == "row" else 1
|
|
|
|
params_dict["scale_bias"] = torch.empty(output_size,
|
|
scale_bias_dim,
|
|
dtype=torch.float32)
|
|
return params_dict
|
|
|
|
@staticmethod
|
|
def process_scale_second(weight: torch.Tensor,
|
|
scale: torch.Tensor,
|
|
per_group_scale: torch.Tensor,
|
|
is_new_quant: bool = False):
|
|
"""Process the scale for second-level quantization.
|
|
|
|
Args:
|
|
weight: weight tensor [k, n] (in new version, n is already compressed to n/2)
|
|
scale: first-level quantization scale [output_size]
|
|
per_group_scale: second-level per-group quantization scale [group_num, n_scale]
|
|
is_new_quant: whether it's the new quantization version (weight already compressed)
|
|
|
|
Returns:
|
|
(antiquant_scale, bias): dequantization scale and bias (bias=None for new version)
|
|
"""
|
|
k, n = weight.shape
|
|
group_num, n_scale = per_group_scale.shape
|
|
|
|
if is_new_quant:
|
|
# Restore logical dimension for compressed weight
|
|
n = n * 2
|
|
|
|
bias = None
|
|
if not is_new_quant:
|
|
weight_high = weight.to(torch.float32).reshape(
|
|
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
|
|
weight_high = weight_high.reshape(k, n)
|
|
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
|
|
# NOTE: scale_bias is not used currently
|
|
# because in msmodelslim w4a8 uses symmetric quantization
|
|
|
|
# TODO: support potential future asymmetric quantization
|
|
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
|
|
return antiquant_scale.npu(), bias
|
|
|
|
def apply(
|
|
self,
|
|
layer: torch.nn.Module,
|
|
x: torch.Tensor,
|
|
bias: Optional[torch.Tensor] = None,
|
|
tp_rank: Optional[int] = None,
|
|
) -> torch.Tensor:
|
|
return torch_npu.npu_weight_quant_batchmatmul(
|
|
x,
|
|
layer.weight,
|
|
antiquant_scale=layer.weight_scale_second.to(x.dtype),
|
|
antiquant_group_size=self.group_size,
|
|
)
|
|
|
|
def process_weights_after_loading(self, layer: torch.nn.Module):
|
|
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
|
layer.weight.data = maybe_trans_nz(layer.weight.data)
|
|
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
|
|
torch.float32)
|
|
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
|
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
|
|
layer.weight.data,
|
|
layer.weight_scale.data,
|
|
layer.weight_scale_second.data.transpose(0, 1).contiguous(),
|
|
is_new_quant=self.new_quant_version,
|
|
)
|
|
|
|
if self.new_quant_version:
|
|
# Process the loaded data based on layer type
|
|
if hasattr(layer, "scale_bias"):
|
|
if layer.scale_bias.data.shape[1] == 1:
|
|
layer.scale_bias.data = layer.scale_bias.data.flatten()
|
|
else:
|
|
layer.scale_bias.data = layer.scale_bias.data.contiguous()
|
|
else:
|
|
if scale_bias is not None:
|
|
param = torch.nn.Parameter(scale_bias, requires_grad=False)
|
|
layer.register_parameter("weight_scale_bias", param)
|
|
|
|
# Convert to NPU-specific int4pack format
|
|
if self.new_quant_version:
|
|
# weights on disk are already in packed int4 format
|
|
# pack 4 int8(int4*2) to int32
|
|
assert layer.weight.data.shape[-1] % 4 == 0, \
|
|
f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
|
|
layer.weight.data = layer.weight.data.view(
|
|
torch.int32).contiguous()
|
|
else:
|
|
# weights are not compressed
|
|
# need to be packed via npu_convert_weight_to_int4pack
|
|
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
|
|
layer.weight.data.to(torch.int32))
|
|
|
|
|
|
@register_scheme("W4A8_DYNAMIC", "moe")
|
|
class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
|
|
"""FusedMoE method for Ascend W4A8_DYNAMIC."""
|
|
|
|
# Declare the quantization type for this scheme
|
|
quant_type: QuantType = QuantType.W4A8
|
|
|
|
def __init__(self):
|
|
self.ep_group = get_ep_group()
|
|
|
|
vllm_config = get_current_vllm_config()
|
|
self.group_size = vllm_config.quant_config.quant_description.get(
|
|
"group_size", 256)
|
|
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
|
|
self.is_per_channel_weight = self.group_size == 0
|
|
quant_version = vllm_config.quant_config.quant_description.get(
|
|
"version", "0")
|
|
# NOTE: new quantize weights: 2 int4 pack into int8
|
|
self.new_quant_version = quant_version == "1.0.0"
|
|
|
|
self.quant_method = vllm_config.quant_config.quant_description.get(
|
|
"ascend_quant_method", "")
|
|
if self.quant_method == COMPRESSED_TENSORS_METHOD:
|
|
self.weight_strategy = vllm_config.quant_config.quant_description.get(
|
|
"weight_strategy", "group")
|
|
|
|
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
|
|
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
|
|
if self.new_quant_version and self.tp_size > 16:
|
|
raise ValueError(
|
|
"The current weight does not support moe part tp>16.")
|
|
|
|
try:
|
|
device_group = get_mc2_group().device_group
|
|
# TODO: Try local_rank = ep_group.rank_in_group
|
|
local_rank = torch.distributed.get_rank(group=device_group)
|
|
backend = device_group._get_backend(torch.device("npu"))
|
|
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
|
local_rank)
|
|
except AttributeError:
|
|
self.moe_all_to_all_group_name = ""
|
|
|
|
def get_weight(self, num_experts: int,
|
|
intermediate_size_per_partition: int, hidden_sizes: int,
|
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
if self.quant_method == COMPRESSED_TENSORS_METHOD:
|
|
return self.get_weight_compressed_tensors(
|
|
num_experts, intermediate_size_per_partition,
|
|
hidden_sizes, params_dtype)
|
|
else:
|
|
return self.get_weight_modelslim(
|
|
num_experts, intermediate_size_per_partition,
|
|
hidden_sizes, params_dtype)
|
|
|
|
def get_weight_compressed_tensors(self, num_experts: int,
|
|
intermediate_size_per_partition: int, hidden_sizes: int,
|
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
|
|
param_dict = {}
|
|
E = num_experts
|
|
H = hidden_sizes
|
|
IN = intermediate_size_per_partition
|
|
g = self.group_size
|
|
|
|
param_dict["w13_weight"] = torch.empty(E, 2 * IN, H,
|
|
dtype=torch.int8)
|
|
param_dict["w2_weight"] = torch.empty(E, H, IN,
|
|
dtype=torch.int8)
|
|
return param_dict
|
|
|
|
|
|
def get_weight_modelslim(self, num_experts: int,
|
|
intermediate_size_per_partition: int, hidden_sizes: int,
|
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
param_dict = {}
|
|
if self.new_quant_version:
|
|
w13_output_size = intermediate_size_per_partition
|
|
w2_output_size = hidden_sizes // 2
|
|
else:
|
|
w13_output_size = 2 * intermediate_size_per_partition
|
|
w2_output_size = hidden_sizes
|
|
|
|
param_dict["w13_weight"] = torch.empty(num_experts,
|
|
w13_output_size,
|
|
hidden_sizes,
|
|
dtype=torch.int8)
|
|
param_dict["w2_weight"] = torch.empty(num_experts,
|
|
w2_output_size,
|
|
intermediate_size_per_partition,
|
|
dtype=torch.int8)
|
|
return param_dict
|
|
|
|
def get_dynamic_quant_param(self, num_experts: int,
|
|
intermediate_size_per_partition: int,
|
|
hidden_sizes: int,
|
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
if self.quant_method == COMPRESSED_TENSORS_METHOD:
|
|
return self.get_dynamic_quant_param_compressed_tensors(
|
|
num_experts, intermediate_size_per_partition,
|
|
hidden_sizes, params_dtype)
|
|
else:
|
|
return self.get_dynamic_quant_param_modelslim(
|
|
num_experts, intermediate_size_per_partition,
|
|
hidden_sizes, params_dtype)
|
|
|
|
def get_dynamic_quant_param_compressed_tensors(self, num_experts: int,
|
|
intermediate_size_per_partition: int,
|
|
hidden_sizes: int,
|
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
param_dict = {}
|
|
|
|
E = num_experts
|
|
H = hidden_sizes
|
|
IN = intermediate_size_per_partition
|
|
g = self.group_size
|
|
|
|
# Per-row scale columns
|
|
def _n_scale_cols(in_features: int) -> int:
|
|
return 1 if g <= 0 else (in_features // g)
|
|
|
|
param_dict["w13_weight_scale"] = torch.empty(
|
|
E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16)
|
|
|
|
param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN),
|
|
dtype=torch.bfloat16)
|
|
|
|
return param_dict
|
|
|
|
def get_dynamic_quant_param_modelslim(self, num_experts: int,
|
|
intermediate_size_per_partition: int,
|
|
hidden_sizes: int,
|
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
param_dict = {}
|
|
param_dict["w13_weight_scale"] = torch.empty(
|
|
num_experts,
|
|
2 * intermediate_size_per_partition,
|
|
1,
|
|
dtype=torch.float32)
|
|
|
|
param_dict["w13_weight_offset"] = torch.empty(
|
|
num_experts,
|
|
2 * intermediate_size_per_partition,
|
|
1,
|
|
dtype=torch.float32)
|
|
|
|
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
|
hidden_sizes,
|
|
1,
|
|
dtype=torch.float32)
|
|
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
|
hidden_sizes,
|
|
1,
|
|
dtype=torch.float32)
|
|
if not self.is_per_channel_weight:
|
|
param_dict["w13_weight_scale_second"] = torch.empty(
|
|
num_experts,
|
|
2 * intermediate_size_per_partition,
|
|
hidden_sizes // self.group_size,
|
|
dtype=torch.float32)
|
|
param_dict["w13_weight_offset_second"] = torch.empty(
|
|
num_experts,
|
|
2 * intermediate_size_per_partition,
|
|
hidden_sizes // self.group_size,
|
|
dtype=torch.float32)
|
|
|
|
param_dict["w2_weight_scale_second"] = torch.empty(
|
|
num_experts,
|
|
hidden_sizes,
|
|
intermediate_size_per_partition // self.group_size,
|
|
dtype=torch.float32)
|
|
param_dict["w2_weight_offset_second"] = torch.empty(
|
|
num_experts,
|
|
hidden_sizes,
|
|
intermediate_size_per_partition // self.group_size,
|
|
dtype=torch.float32)
|
|
|
|
if self.new_quant_version:
|
|
param_dict["w13_scale_bias"] = torch.empty(
|
|
num_experts,
|
|
2 * intermediate_size_per_partition,
|
|
1,
|
|
dtype=torch.float32)
|
|
param_dict["w2_scale_bias"] = torch.empty(num_experts,
|
|
hidden_sizes,
|
|
16 // self.tp_size,
|
|
dtype=torch.float32)
|
|
|
|
return param_dict
|
|
|
|
def apply(
|
|
self,
|
|
layer: torch.nn.Module,
|
|
x: torch.Tensor,
|
|
router_logits: torch.Tensor,
|
|
top_k: int,
|
|
renormalize: bool,
|
|
use_grouped_topk: bool = False,
|
|
global_num_experts: int = -1,
|
|
expert_map: Optional[torch.Tensor] = None,
|
|
topk_group: Optional[int] = None,
|
|
num_expert_group: Optional[int] = None,
|
|
custom_routing_function: Optional[Callable] = None,
|
|
scoring_func: str = "softmax",
|
|
routed_scaling_factor: float = 1.0,
|
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
|
is_prefill: bool = True,
|
|
enable_force_load_balance: bool = False,
|
|
log2phy: Optional[torch.Tensor] = None,
|
|
global_redundant_expert_num: int = 0,
|
|
**kwargs,
|
|
) -> torch.Tensor:
|
|
assert router_logits.shape[
|
|
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
|
|
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
|
topk_weights, topk_ids = select_experts(
|
|
hidden_states=x,
|
|
router_logits=router_logits,
|
|
top_k=top_k,
|
|
use_grouped_topk=use_grouped_topk,
|
|
renormalize=renormalize,
|
|
topk_group=topk_group,
|
|
num_expert_group=num_expert_group,
|
|
custom_routing_function=custom_routing_function,
|
|
scoring_func=scoring_func,
|
|
e_score_correction_bias=e_score_correction_bias,
|
|
global_num_experts=global_num_experts)
|
|
|
|
# this is a naive implementation for experts load balance so as
|
|
# to avoid accumulating too much tokens on a single rank.
|
|
# currently it is only activated when doing profile runs.
|
|
if enable_force_load_balance:
|
|
random_matrix = torch.rand(topk_ids.size(0),
|
|
global_num_experts -
|
|
global_redundant_expert_num,
|
|
device=topk_ids.device)
|
|
topk_ids = torch.argsort(
|
|
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
|
|
|
topk_weights = topk_weights.to(x.dtype)
|
|
|
|
moe_comm_method = get_forward_context().moe_comm_method
|
|
return moe_comm_method.fused_experts(
|
|
hidden_states=x,
|
|
w1=[layer.w13_weight],
|
|
w2=[layer.w2_weight],
|
|
w1_scale=[layer.w13_weight_scale],
|
|
w2_scale=[layer.w2_weight_scale],
|
|
w1_scale_bias=layer.w13_scale_bias if hasattr(
|
|
layer, "w13_scale_bias") else None,
|
|
w2_scale_bias=layer.w2_scale_bias if hasattr(
|
|
layer, "w2_scale_bias") else None,
|
|
topk_weights=topk_weights,
|
|
topk_ids=topk_ids,
|
|
use_int4_w4a8=True,
|
|
expert_map=expert_map,
|
|
log2phy=log2phy,
|
|
dynamic_eplb=self.dynamic_eplb,
|
|
mc2_mask=kwargs.get("mc2_mask", None))
|
|
|
|
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
|
scale = scale.transpose(1, 2).contiguous()
|
|
if self.is_per_channel_weight:
|
|
scale_np = scale.cpu().numpy()
|
|
scale_np.dtype = np.uint32
|
|
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
|
|
np.int64)).npu()
|
|
return scale_uint64_tensor, None
|
|
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
|
|
group_num, k, n = weight.shape
|
|
# the weight of the new version is reduced by half by pack n, so it needs to be restored
|
|
if self.new_quant_version:
|
|
n = n * 2
|
|
per_group_scale = per_group_scale.reshape(group_num, -1, n)
|
|
group_num, quantgroup_num, n = per_group_scale.shape
|
|
bias = None
|
|
if not self.new_quant_version:
|
|
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
|
|
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
|
|
weight_high = weight_high.reshape([group_num, k, n])
|
|
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
|
|
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
|
|
torch.float32)
|
|
scale_fp32_np = scale_fp32.cpu().numpy()
|
|
scale_fp32_np.dtype = np.uint32
|
|
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
|
|
dtype=np.uint32)
|
|
|
|
sscale_uint64[..., ::2] = scale_fp32_np
|
|
|
|
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
|
|
dtype=np.int64).copy()
|
|
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
|
|
group_num, quantgroup_num, n)
|
|
sscale_uint64_tensor = sscale_uint64_tensor.npu()
|
|
return sscale_uint64_tensor, bias
|
|
|
|
def update_bias(self, layer, w13_bias, w2_bias):
|
|
if self.new_quant_version:
|
|
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
|
|
1, 2).contiguous().sum(axis=1)
|
|
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
|
|
1, 2).contiguous().sum(axis=1)
|
|
else:
|
|
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
|
|
layer.register_parameter("w13_scale_bias", w13_scale_bias)
|
|
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
|
|
layer.register_parameter("w2_scale_bias", w2_scale_bias)
|
|
|
|
def pack_to_int32(self, weight: torch.Tensor):
|
|
if self.new_quant_version:
|
|
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
|
|
assert weight.shape[
|
|
-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
|
|
return weight.view(torch.int32).contiguous()
|
|
else:
|
|
return torch_npu.npu_quantize(weight.to(torch.float32),
|
|
torch.tensor([1.]).npu(), None,
|
|
torch.quint4x2, -1, False)
|
|
|
|
def process_weights_after_loading(self, layer):
|
|
if self.quant_method == COMPRESSED_TENSORS_METHOD:
|
|
self.process_weights_after_loading_compressed_tensors(layer)
|
|
else:
|
|
self.process_weights_after_loading_modelslim(layer)
|
|
|
|
|
|
def process_weights_after_loading_compressed_tensors(self, layer):
|
|
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
|
1, 2).contiguous()
|
|
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
|
|
2).contiguous()
|
|
|
|
def process_scale_compressed_tensors(scale: torch.Tensor):
|
|
scale = scale.transpose(1, 2).to(torch.float32).contiguous()
|
|
scale_np = scale.cpu().numpy()
|
|
scale_np.dtype = np.uint32
|
|
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
|
|
np.int64)).npu()
|
|
return scale_uint64_tensor
|
|
|
|
def update_bias_compressed_tensors(weight: torch.Tensor,
|
|
scale: torch.Tensor, strategy:str):
|
|
group_num, k, n = weight.shape
|
|
scale = scale.transpose(1, 2).contiguous()
|
|
scale = scale.reshape(group_num, -1, n)
|
|
group_num, quantgroup_num, n = scale.shape
|
|
|
|
bias = None
|
|
if strategy == "group":
|
|
tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
|
|
scale.reshape([group_num, quantgroup_num, 1, n])
|
|
tmp = tmp.reshape([group_num, k, n])
|
|
bias = 8 * tmp.sum(axis=1)
|
|
elif strategy == "channel":
|
|
bias = 8 * (weight.to(torch.float32) * scale).sum(axis=1)
|
|
else:
|
|
raise ValueError(f"Unsupported weight strategy: {strategy}")
|
|
return bias
|
|
|
|
w13_bias = update_bias_compressed_tensors(layer.w13_weight.data,
|
|
layer.w13_weight_scale.data,
|
|
self.weight_strategy)
|
|
w2_bias = update_bias_compressed_tensors(layer.w2_weight.data,
|
|
layer.w2_weight_scale.data,
|
|
self.weight_strategy)
|
|
|
|
layer.w13_weight_scale.data = process_scale_compressed_tensors(
|
|
layer.w13_weight_scale.data)
|
|
layer.w2_weight_scale.data = process_scale_compressed_tensors(
|
|
layer.w2_weight_scale.data)
|
|
|
|
|
|
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
|
|
layer.register_parameter("w13_scale_bias", w13_scale_bias)
|
|
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
|
|
layer.register_parameter("w2_scale_bias", w2_scale_bias)
|
|
|
|
# Accuracy problem in nz format
|
|
# layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
|
|
# layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
|
|
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
|
|
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
|
|
|
|
def process_weights_after_loading_modelslim(self, layer):
|
|
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
|
1, 2).contiguous()
|
|
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
|
|
2).contiguous()
|
|
|
|
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
|
|
layer, "w13_weight_scale_second") else None
|
|
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
|
|
layer, "w2_weight_scale_second") else None
|
|
layer.w13_weight_scale.data, w13_bias = self.process_scale(
|
|
layer.w13_weight, layer.w13_weight_scale.data,
|
|
w13_weight_scale_second)
|
|
layer.w2_weight_scale.data, w2_bias = self.process_scale(
|
|
layer.w2_weight, layer.w2_weight_scale.data,
|
|
w2_weight_scale_second)
|
|
if hasattr(layer, "w13_weight_scale_second"):
|
|
# scale_second is no longer used, release this part of the memory
|
|
del layer.w13_weight_scale_second
|
|
del layer.w2_weight_scale_second
|
|
del layer.w13_weight_offset_second
|
|
del layer.w2_weight_offset_second
|
|
|
|
self.update_bias(layer, w13_bias, w2_bias)
|
|
|
|
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
|
|
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
|
|
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
|
|
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
|