[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #7) (#6023)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
|` vllm_ascend/quantization/compressed_tensors/compressed_tensors.py`|
|` vllm_ascend/quantization/quant_config.py`|
|` vllm_ascend/quantization/utils.py`|
|` vllm_ascend/quantization/w4a16.py`|
|` vllm_ascend/quantization/w4a4_flatquant_dynamic.py`|
|` vllm_ascend/quantization/w4a8_dynamic.py`|
|` vllm_ascend/quantization/w8a16.py`|
|` vllm_ascend/quantization/w8a8.py`|
|` vllm_ascend/quantization/w8a8_dynamic.py`|
|` vllm_ascend/quantization/w8a8_pdmix.py`|
|` vllm_ascend/quantization/w8a8mxfp8.py`|
|` vllm_ascend/sample/rejection_sampler.py`|
|` vllm_ascend/sample/sampler.py`|
|` vllm_ascend/worker/block_table.py`|

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-02-06 14:56:53 +08:00
committed by GitHub
parent d0bc16859c
commit 99aedaff63
20 changed files with 997 additions and 1307 deletions

View File

@@ -15,7 +15,8 @@
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
from collections.abc import Callable
from typing import Any
import numpy as np
import torch
@@ -27,7 +28,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import maybe_trans_nz, COMPRESSED_TENSORS_METHOD
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
from .registry import register_scheme
@@ -39,19 +40,17 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
def __init__(self):
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256)
quant_version = vllm_config.quant_config.quant_description.get(
"version", "0")
self.group_size = vllm_config.quant_config.quant_description.get("group_size", 256)
quant_version = vllm_config.quant_config.quant_description.get("version", "0")
self.new_quant_version = quant_version == "1.0.0"
from vllm.distributed import get_tensor_model_parallel_world_size
self.tp_size = get_tensor_model_parallel_world_size()
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
"""Create weight parameters.
For new quantization version (double int4 pack into int8), the output dimension
is compressed by factor 2 (e.g., [2048, 3072] -> [1024, 3072]). The returned
dict includes "_packed_dim" and "_packed_factor" for vLLM's weight loader.
@@ -62,40 +61,26 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
# double int4 pack into int8: output dimension is compressed
pack_factor = 2
actual_output_size = output_size // pack_factor
params_dict["weight"] = torch.empty(actual_output_size,
input_size,
dtype=torch.int8)
params_dict["weight"] = torch.empty(actual_output_size, input_size, dtype=torch.int8)
# Add packing information for vLLM's weight_loader
params_dict["_packed_dim"] = 0
params_dict["_packed_factor"] = pack_factor
else:
params_dict["weight"] = torch.empty(output_size,
input_size,
dtype=torch.int8)
params_dict["weight"] = torch.empty(output_size, input_size, dtype=torch.int8)
return params_dict
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
def get_pergroup_param(
self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None
) -> dict[str, Any]:
"""Create per-group quantization parameters."""
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_scale_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
params_dict["weight_offset_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_scale_second"] = torch.empty(output_size, input_size // self.group_size, dtype=params_dtype)
params_dict["weight_offset_second"] = torch.empty(
output_size, input_size // self.group_size, dtype=params_dtype
)
# NOTE: In w4a8 quantization implementation,
# for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16],
@@ -103,24 +88,21 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
if self.new_quant_version:
scale_bias_dim = 16 if layer_type == "row" else 1
params_dict["scale_bias"] = torch.empty(output_size,
scale_bias_dim,
dtype=torch.float32)
params_dict["scale_bias"] = torch.empty(output_size, scale_bias_dim, dtype=torch.float32)
return params_dict
@staticmethod
def process_scale_second(weight: torch.Tensor,
scale: torch.Tensor,
per_group_scale: torch.Tensor,
is_new_quant: bool = False):
def process_scale_second(
weight: torch.Tensor, scale: torch.Tensor, per_group_scale: torch.Tensor, is_new_quant: bool = False
):
"""Process the scale for second-level quantization.
Args:
weight: weight tensor [k, n] (in new version, n is already compressed to n/2)
scale: first-level quantization scale [output_size]
per_group_scale: second-level per-group quantization scale [group_num, n_scale]
is_new_quant: whether it's the new quantization version (weight already compressed)
Returns:
(antiquant_scale, bias): dequantization scale and bias (bias=None for new version)
"""
@@ -133,8 +115,7 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
bias = None
if not is_new_quant:
weight_high = weight.to(torch.float32).reshape(
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
weight_high = weight.to(torch.float32).reshape(group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
weight_high = weight_high.reshape(k, n)
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
# NOTE: scale_bias is not used currently
@@ -148,8 +129,8 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = None,
bias: torch.Tensor | None = None,
tp_rank: int | None = None,
) -> torch.Tensor:
return torch_npu.npu_weight_quant_batchmatmul(
x,
@@ -161,8 +142,7 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
def process_weights_after_loading(self, layer: torch.nn.Module):
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = maybe_trans_nz(layer.weight.data)
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
torch.float32)
layer.weight_scale.data = layer.weight_scale.data.flatten().to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
layer.weight.data,
@@ -187,15 +167,14 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
if self.new_quant_version:
# weights on disk are already in packed int4 format
# pack 4 int8(int4*2) to int32
assert layer.weight.data.shape[-1] % 4 == 0, \
assert layer.weight.data.shape[-1] % 4 == 0, (
f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
layer.weight.data = layer.weight.data.view(
torch.int32).contiguous()
)
layer.weight.data = layer.weight.data.view(torch.int32).contiguous()
else:
# weights are not compressed
# need to be packed via npu_convert_weight_to_int4pack
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
layer.weight.data.to(torch.int32))
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(layer.weight.data.to(torch.int32))
@register_scheme("W4A8_DYNAMIC", "moe")
@@ -209,69 +188,56 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
self.ep_group = get_ep_group()
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256)
self.group_size = vllm_config.quant_config.quant_description.get("group_size", 256)
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
self.is_per_channel_weight = self.group_size == 0
quant_version = vllm_config.quant_config.quant_description.get(
"version", "0")
quant_version = vllm_config.quant_config.quant_description.get("version", "0")
# NOTE: new quantize weights: 2 int4 pack into int8
self.new_quant_version = quant_version == "1.0.0"
self.quant_method = vllm_config.quant_config.quant_description.get(
"ascend_quant_method", "")
self.quant_method = vllm_config.quant_config.quant_description.get("ascend_quant_method", "")
if self.quant_method == COMPRESSED_TENSORS_METHOD:
self.weight_strategy = vllm_config.quant_config.quant_description.get(
"weight_strategy", "group")
self.weight_strategy = vllm_config.quant_config.quant_description.get("weight_strategy", "group")
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
if self.new_quant_version and self.tp_size > 16:
raise ValueError(
"The current weight does not support moe part tp>16.")
raise ValueError("The current weight does not support moe part tp>16.")
try:
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
except AttributeError:
self.moe_all_to_all_group_name = ""
def get_weight(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
if self.quant_method == COMPRESSED_TENSORS_METHOD:
return self.get_weight_compressed_tensors(
num_experts, intermediate_size_per_partition,
hidden_sizes, params_dtype)
num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
)
else:
return self.get_weight_modelslim(
num_experts, intermediate_size_per_partition,
hidden_sizes, params_dtype)
def get_weight_compressed_tensors(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
return self.get_weight_modelslim(num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype)
def get_weight_compressed_tensors(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
E = num_experts
H = hidden_sizes
IN = intermediate_size_per_partition
g = self.group_size
param_dict["w13_weight"] = torch.empty(E, 2 * IN, H,
dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(E, H, IN,
dtype=torch.int8)
param_dict["w13_weight"] = torch.empty(E, 2 * IN, H, dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(E, H, IN, dtype=torch.int8)
return param_dict
def get_weight_modelslim(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight_modelslim(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
if self.new_quant_version:
w13_output_size = intermediate_size_per_partition
@@ -280,33 +246,27 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
w13_output_size = 2 * intermediate_size_per_partition
w2_output_size = hidden_sizes
param_dict["w13_weight"] = torch.empty(num_experts,
w13_output_size,
hidden_sizes,
dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(num_experts,
w2_output_size,
intermediate_size_per_partition,
dtype=torch.int8)
param_dict["w13_weight"] = torch.empty(num_experts, w13_output_size, hidden_sizes, dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(
num_experts, w2_output_size, intermediate_size_per_partition, dtype=torch.int8
)
return param_dict
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_dynamic_quant_param(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
if self.quant_method == COMPRESSED_TENSORS_METHOD:
return self.get_dynamic_quant_param_compressed_tensors(
num_experts, intermediate_size_per_partition,
hidden_sizes, params_dtype)
num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
)
else:
return self.get_dynamic_quant_param_modelslim(
num_experts, intermediate_size_per_partition,
hidden_sizes, params_dtype)
def get_dynamic_quant_param_compressed_tensors(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
)
def get_dynamic_quant_param_compressed_tensors(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
E = num_experts
@@ -318,72 +278,48 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
def _n_scale_cols(in_features: int) -> int:
return 1 if g <= 0 else (in_features // g)
param_dict["w13_weight_scale"] = torch.empty(
E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16)
param_dict["w13_weight_scale"] = torch.empty(E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16)
param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN),
dtype=torch.bfloat16)
param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN), dtype=torch.bfloat16)
return param_dict
def get_dynamic_quant_param_modelslim(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_dynamic_quant_param_modelslim(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
)
param_dict["w13_weight_offset"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
)
param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=torch.float32)
param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=torch.float32)
param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, dtype=torch.float32)
param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, dtype=torch.float32)
if not self.is_per_channel_weight:
param_dict["w13_weight_scale_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.float32)
num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.float32
)
param_dict["w13_weight_offset_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.float32)
num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.float32
)
param_dict["w2_weight_scale_second"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.float32)
num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.float32
)
param_dict["w2_weight_offset_second"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.float32)
num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.float32
)
if self.new_quant_version:
param_dict["w13_scale_bias"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
param_dict["w2_scale_bias"] = torch.empty(num_experts,
hidden_sizes,
16 // self.tp_size,
dtype=torch.float32)
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
)
param_dict["w2_scale_bias"] = torch.empty(
num_experts, hidden_sizes, 16 // self.tp_size, dtype=torch.float32
)
return param_dict
@@ -396,21 +332,22 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)"
)
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
topk_weights, topk_ids = select_experts(
@@ -424,18 +361,17 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
global_num_experts=global_num_experts,
)
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts -
global_redundant_expert_num,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
random_matrix = torch.rand(
topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
)
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
topk_weights = topk_weights.to(x.dtype)
@@ -446,25 +382,23 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
w2=[layer.w2_weight],
w1_scale=[layer.w13_weight_scale],
w2_scale=[layer.w2_weight_scale],
w1_scale_bias=layer.w13_scale_bias if hasattr(
layer, "w13_scale_bias") else None,
w2_scale_bias=layer.w2_scale_bias if hasattr(
layer, "w2_scale_bias") else None,
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int4_w4a8=True,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
mc2_mask=kwargs.get("mc2_mask"),
)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
scale = scale.transpose(1, 2).contiguous()
if self.is_per_channel_weight:
scale_np = scale.cpu().numpy()
scale_np.dtype = np.uint32
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
np.int64)).npu()
scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()
return scale_uint64_tensor, None
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
group_num, k, n = weight.shape
@@ -475,32 +409,27 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
group_num, quantgroup_num, n = per_group_scale.shape
bias = None
if not self.new_quant_version:
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
weight_high = weight.to(torch.float32).reshape(
[group_num, quantgroup_num, -1, n]
) * per_group_scale.reshape([group_num, quantgroup_num, 1, n])
weight_high = weight_high.reshape([group_num, k, n])
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
torch.float32)
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(torch.float32)
scale_fp32_np = scale_fp32.cpu().numpy()
scale_fp32_np.dtype = np.uint32
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
dtype=np.uint32)
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), dtype=np.uint32)
sscale_uint64[..., ::2] = scale_fp32_np
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
dtype=np.int64).copy()
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
group_num, quantgroup_num, n)
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), dtype=np.int64).copy()
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(group_num, quantgroup_num, n)
sscale_uint64_tensor = sscale_uint64_tensor.npu()
return sscale_uint64_tensor, bias
def update_bias(self, layer, w13_bias, w2_bias):
if self.new_quant_version:
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1)
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1)
else:
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.register_parameter("w13_scale_bias", w13_scale_bias)
@@ -510,13 +439,12 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
def pack_to_int32(self, weight: torch.Tensor):
if self.new_quant_version:
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
assert weight.shape[
-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
assert weight.shape[-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
return weight.view(torch.int32).contiguous()
else:
return torch_npu.npu_quantize(weight.to(torch.float32),
torch.tensor([1.]).npu(), None,
torch.quint4x2, -1, False)
return torch_npu.npu_quantize(
weight.to(torch.float32), torch.tensor([1.0]).npu(), None, torch.quint4x2, -1, False
)
def process_weights_after_loading(self, layer):
if self.quant_method == COMPRESSED_TENSORS_METHOD:
@@ -524,23 +452,18 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
else:
self.process_weights_after_loading_modelslim(layer)
def process_weights_after_loading_compressed_tensors(self, layer):
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous()
def process_scale_compressed_tensors(scale: torch.Tensor):
scale = scale.transpose(1, 2).to(torch.float32).contiguous()
scale_np = scale.cpu().numpy()
scale_np.dtype = np.uint32
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
np.int64)).npu()
scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()
return scale_uint64_tensor
def update_bias_compressed_tensors(weight: torch.Tensor,
scale: torch.Tensor, strategy:str):
def update_bias_compressed_tensors(weight: torch.Tensor, scale: torch.Tensor, strategy: str):
group_num, k, n = weight.shape
scale = scale.transpose(1, 2).contiguous()
scale = scale.reshape(group_num, -1, n)
@@ -548,8 +471,9 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
bias = None
if strategy == "group":
tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
scale.reshape([group_num, quantgroup_num, 1, n])
tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * scale.reshape(
[group_num, quantgroup_num, 1, n]
)
tmp = tmp.reshape([group_num, k, n])
bias = 8 * tmp.sum(axis=1)
elif strategy == "channel":
@@ -558,19 +482,14 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
raise ValueError(f"Unsupported weight strategy: {strategy}")
return bias
w13_bias = update_bias_compressed_tensors(layer.w13_weight.data,
layer.w13_weight_scale.data,
self.weight_strategy)
w2_bias = update_bias_compressed_tensors(layer.w2_weight.data,
layer.w2_weight_scale.data,
self.weight_strategy)
w13_bias = update_bias_compressed_tensors(
layer.w13_weight.data, layer.w13_weight_scale.data, self.weight_strategy
)
w2_bias = update_bias_compressed_tensors(layer.w2_weight.data, layer.w2_weight_scale.data, self.weight_strategy)
layer.w13_weight_scale.data = process_scale_compressed_tensors(
layer.w13_weight_scale.data)
layer.w2_weight_scale.data = process_scale_compressed_tensors(
layer.w2_weight_scale.data)
layer.w13_weight_scale.data = process_scale_compressed_tensors(layer.w13_weight_scale.data)
layer.w2_weight_scale.data = process_scale_compressed_tensors(layer.w2_weight_scale.data)
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.register_parameter("w13_scale_bias", w13_scale_bias)
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
@@ -583,21 +502,19 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
def process_weights_after_loading_modelslim(self, layer):
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous()
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
layer, "w13_weight_scale_second") else None
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
layer, "w2_weight_scale_second") else None
w13_weight_scale_second = (
layer.w13_weight_scale_second.data if hasattr(layer, "w13_weight_scale_second") else None
)
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(layer, "w2_weight_scale_second") else None
layer.w13_weight_scale.data, w13_bias = self.process_scale(
layer.w13_weight, layer.w13_weight_scale.data,
w13_weight_scale_second)
layer.w13_weight, layer.w13_weight_scale.data, w13_weight_scale_second
)
layer.w2_weight_scale.data, w2_bias = self.process_scale(
layer.w2_weight, layer.w2_weight_scale.data,
w2_weight_scale_second)
layer.w2_weight, layer.w2_weight_scale.data, w2_weight_scale_second
)
if hasattr(layer, "w13_weight_scale_second"):
# scale_second is no longer used, release this part of the memory
del layer.w13_weight_scale_second