[Quantization][Feature] Support compressed tensors moe w4a8 dynamic weight (#5889)
### What this PR does / why we need it?
When quantized weights are generated with the LLM Compressor tool from the vLLM community, the vLLM Ascend engine needs to be adapted to support the compressed-tensors quantization format.
1. Support W4A8 dynamic weights for MoE models.
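For orientation, here is a minimal sketch of the weight-side numerics that "W4A8 dynamic" implies: int4 weights with per-group scales, while activations are quantized to int8 dynamically at runtime. This is an illustration only, not code from this PR; `fake_quant_w4_group` and all sizes are made up.

```python
# Sketch of W4A8 weight-side numerics: int4 weights, symmetric per-group scales.
import torch

def fake_quant_w4_group(w: torch.Tensor, group_size: int = 128):
    out_f, in_f = w.shape
    wg = w.reshape(out_f, in_f // group_size, group_size)
    scale = wg.abs().amax(dim=-1, keepdim=True) / 7.0  # int4 range is [-8, 7]
    q = torch.clamp(torch.round(wg / scale), -8, 7)
    dequant = (q * scale).reshape(out_f, in_f)
    return dequant, q.to(torch.int8), scale

w = torch.randn(16, 256)
w_dq, q, scale = fake_quant_w4_group(w)
print((w - w_dq).abs().max())  # small error, bounded by half a quantization step
```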
- vLLM version: v0.13.0
- vLLM main:
bde38c11df
---------
Signed-off-by: LHXuuu <scut_xlh@163.com>
Signed-off-by: menogrey <1299267905@qq.com>
Co-authored-by: menogrey <1299267905@qq.com>
@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
-from vllm_ascend.utils import maybe_trans_nz
+from vllm_ascend.utils import maybe_trans_nz, COMPRESSED_TENSORS_METHOD
 
 from .base import AscendLinearScheme, AscendMoEScheme, QuantType
 from .registry import register_scheme
@@ -217,6 +217,13 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
             "version", "0")
         # NOTE: new quantized weights: 2 int4 values packed into one int8
         self.new_quant_version = quant_version == "1.0.0"
+
+        self.quant_method = vllm_config.quant_config.quant_description.get(
+            "ascend_quant_method", "")
+        if self.quant_method == COMPRESSED_TENSORS_METHOD:
+            self.weight_strategy = vllm_config.quant_config.quant_description.get(
+                "weight_strategy", "group")
+
         self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
         self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
         if self.new_quant_version and self.tp_size > 16:
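The dispatch above keys off two entries in `quant_description`. A hypothetical config that would route a layer down the new path might look like the dict below; the literal string behind `COMPRESSED_TENSORS_METHOD` lives in `vllm_ascend.utils`, so `"compressed-tensors"` here is an assumed placeholder.

```python
# Hypothetical quant_description selecting the compressed-tensors MoE path.
quant_description = {
    "version": "1.0.0",                           # new_quant_version: int4 pairs packed into int8
    "ascend_quant_method": "compressed-tensors",  # assumed value of COMPRESSED_TENSORS_METHOD
    "weight_strategy": "group",                   # or "channel"
}
```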
@@ -236,6 +243,35 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
     def get_weight(self, num_experts: int,
                    intermediate_size_per_partition: int, hidden_sizes: int,
                    params_dtype: torch.dtype) -> Dict[str, Any]:
+        if self.quant_method == COMPRESSED_TENSORS_METHOD:
+            return self.get_weight_compressed_tensors(
+                num_experts, intermediate_size_per_partition,
+                hidden_sizes, params_dtype)
+        else:
+            return self.get_weight_modelslim(
+                num_experts, intermediate_size_per_partition,
+                hidden_sizes, params_dtype)
+
+    def get_weight_compressed_tensors(self, num_experts: int,
+                                      intermediate_size_per_partition: int, hidden_sizes: int,
+                                      params_dtype: torch.dtype) -> Dict[str, Any]:
+
+        param_dict = {}
+        E = num_experts
+        H = hidden_sizes
+        IN = intermediate_size_per_partition
+        g = self.group_size
+
+        param_dict["w13_weight"] = torch.empty(E, 2 * IN, H,
+                                               dtype=torch.int8)
+        param_dict["w2_weight"] = torch.empty(E, H, IN,
+                                              dtype=torch.int8)
+        return param_dict
+
+
+    def get_weight_modelslim(self, num_experts: int,
+                             intermediate_size_per_partition: int, hidden_sizes: int,
+                             params_dtype: torch.dtype) -> Dict[str, Any]:
         param_dict = {}
         if self.new_quant_version:
             w13_output_size = intermediate_size_per_partition
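As a shape sanity check (sizes below are illustrative, not from the PR): `get_weight_compressed_tensors` allocates one int8 container per int4 value, with the fused gate/up projection stacked along the output dimension; packing into int32 happens later, in `process_weights_after_loading`.

```python
# Expected weight shapes from get_weight_compressed_tensors, with made-up sizes:
# E experts, hidden size H, per-partition intermediate size IN.
import torch

E, H, IN = 64, 4096, 1408
w13 = torch.empty(E, 2 * IN, H, dtype=torch.int8)  # fused gate + up projections
w2 = torch.empty(E, H, IN, dtype=torch.int8)       # down projection
assert w13.shape == (64, 2816, 4096)
assert w2.shape == (64, 4096, 1408)
```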
@@ -258,6 +294,42 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
                                 intermediate_size_per_partition: int,
                                 hidden_sizes: int,
                                 params_dtype: torch.dtype) -> Dict[str, Any]:
+        if self.quant_method == COMPRESSED_TENSORS_METHOD:
+            return self.get_dynamic_quant_param_compressed_tensors(
+                num_experts, intermediate_size_per_partition,
+                hidden_sizes, params_dtype)
+        else:
+            return self.get_dynamic_quant_param_modelslim(
+                num_experts, intermediate_size_per_partition,
+                hidden_sizes, params_dtype)
+
+    def get_dynamic_quant_param_compressed_tensors(self, num_experts: int,
+                                                   intermediate_size_per_partition: int,
+                                                   hidden_sizes: int,
+                                                   params_dtype: torch.dtype) -> Dict[str, Any]:
+        param_dict = {}
+
+        E = num_experts
+        H = hidden_sizes
+        IN = intermediate_size_per_partition
+        g = self.group_size
+
+        # Per-row scale columns
+        def _n_scale_cols(in_features: int) -> int:
+            return 1 if g <= 0 else (in_features // g)
+
+        param_dict["w13_weight_scale"] = torch.empty(
+            E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16)
+
+        param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN),
+                                                    dtype=torch.bfloat16)
+
+        return param_dict
+
+    def get_dynamic_quant_param_modelslim(self, num_experts: int,
+                                          intermediate_size_per_partition: int,
+                                          hidden_sizes: int,
+                                          params_dtype: torch.dtype) -> Dict[str, Any]:
         param_dict = {}
         param_dict["w13_weight_scale"] = torch.empty(
             num_experts,
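The `_n_scale_cols` helper decides how many scale columns each output row gets: one per quantization group along the input dimension, collapsing to a single per-channel scale when the group size is non-positive. Reproduced standalone for illustration:

```python
# Behavior of _n_scale_cols, reproduced outside the class for illustration.
def n_scale_cols(in_features: int, g: int) -> int:
    return 1 if g <= 0 else in_features // g

print(n_scale_cols(4096, 128))  # 32 -> group-wise: one scale per 128 inputs
print(n_scale_cols(4096, -1))   # 1  -> channel-wise: one scale per output row
```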
@@ -374,8 +446,10 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
             w2=[layer.w2_weight],
             w1_scale=[layer.w13_weight_scale],
             w2_scale=[layer.w2_weight_scale],
-            w1_scale_bias=layer.w13_scale_bias,
-            w2_scale_bias=layer.w2_scale_bias,
+            w1_scale_bias=layer.w13_scale_bias if hasattr(
+                layer, "w13_scale_bias") else None,
+            w2_scale_bias=layer.w2_scale_bias if hasattr(
+                layer, "w2_scale_bias") else None,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             use_int4_w4a8=True,
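The `hasattr` guard lets the same kernel call serve checkpoints that never register scale-bias parameters. An equivalent, arguably tidier idiom would be `getattr` with a default (a sketch, not the PR's code):

```python
# Equivalent to the hasattr guards above.
w1_scale_bias = getattr(layer, "w13_scale_bias", None)
w2_scale_bias = getattr(layer, "w2_scale_bias", None)
```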
@@ -445,6 +519,70 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
                                       torch.quint4x2, -1, False)
 
     def process_weights_after_loading(self, layer):
+        if self.quant_method == COMPRESSED_TENSORS_METHOD:
+            self.process_weights_after_loading_compressed_tensors(layer)
+        else:
+            self.process_weights_after_loading_modelslim(layer)
+
+
+    def process_weights_after_loading_compressed_tensors(self, layer):
+        layer.w13_weight.data = layer.w13_weight.data.transpose(
+            1, 2).contiguous()
+        layer.w2_weight.data = layer.w2_weight.data.transpose(1,
+                                                              2).contiguous()
+
+        def process_scale_compressed_tensors(scale: torch.Tensor):
+            scale = scale.transpose(1, 2).to(torch.float32).contiguous()
+            scale_np = scale.cpu().numpy()
+            scale_np.dtype = np.uint32
+            scale_uint64_tensor = torch.from_numpy(scale_np.astype(
+                np.int64)).npu()
+            return scale_uint64_tensor
+
+        def update_bias_compressed_tensors(weight: torch.Tensor,
+                                           scale: torch.Tensor, strategy: str):
+            group_num, k, n = weight.shape
+            scale = scale.transpose(1, 2).contiguous()
+            scale = scale.reshape(group_num, -1, n)
+            group_num, quantgroup_num, n = scale.shape
+
+            bias = None
+            if strategy == "group":
+                tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
+                    scale.reshape([group_num, quantgroup_num, 1, n])
+                tmp = tmp.reshape([group_num, k, n])
+                bias = 8 * tmp.sum(axis=1)
+            elif strategy == "channel":
+                bias = 8 * (weight.to(torch.float32) * scale).sum(axis=1)
+            else:
+                raise ValueError(f"Unsupported weight strategy: {strategy}")
+            return bias
+
+        w13_bias = update_bias_compressed_tensors(layer.w13_weight.data,
+                                                  layer.w13_weight_scale.data,
+                                                  self.weight_strategy)
+        w2_bias = update_bias_compressed_tensors(layer.w2_weight.data,
+                                                 layer.w2_weight_scale.data,
+                                                 self.weight_strategy)
+
+        layer.w13_weight_scale.data = process_scale_compressed_tensors(
+            layer.w13_weight_scale.data)
+        layer.w2_weight_scale.data = process_scale_compressed_tensors(
+            layer.w2_weight_scale.data)
+
+
+        w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
+        layer.register_parameter("w13_scale_bias", w13_scale_bias)
+        w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
+        layer.register_parameter("w2_scale_bias", w2_scale_bias)
+
+        # Accuracy problem in nz format
+        # layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
+        # layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
+        layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
+        layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
+
+    def process_weights_after_loading_modelslim(self, layer):
         layer.w13_weight.data = layer.w13_weight.data.transpose(
             1, 2).contiguous()
         layer.w2_weight.data = layer.w2_weight.data.transpose(1,
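On the factor of 8 in `update_bias_compressed_tensors`: 8 is the int4 zero-point offset (int4 values span [-8, 7]), and `scale_bias` precomputes `8 * sum(weight * scale)` per output channel so the fused MoE kernel can undo the shift applied when values move into unsigned range. That reading of the kernel contract is inferred, not stated in the diff. The check below only verifies the grouped-reshape arithmetic of the "group" branch against a direct per-element expansion of the scales; all shapes and sizes are illustrative.

```python
# Sanity check of the "group" branch arithmetic in update_bias_compressed_tensors.
# Shapes mirror the method: weight [E, K, N], one scale per group of g inputs.
import torch

E, K, N, g = 2, 64, 8, 16
w_q = torch.randint(-8, 8, (E, K, N)).float()   # signed int4 weight values
scale = torch.rand(E, K // g, N)                # per-group scales

# Grouped form, as in the PR's "group" strategy:
bias_group = 8 * (w_q.reshape(E, K // g, g, N) *
                  scale.reshape(E, K // g, 1, N)).reshape(E, K, N).sum(dim=1)

# Direct form: expand scales over K, then reduce.
bias_direct = 8 * (w_q * scale.repeat_interleave(g, dim=1)).sum(dim=1)

assert torch.allclose(bias_group, bias_direct)
```

The companion helper `process_scale_compressed_tensors` reinterprets the float32 scale bits as uint32 and widens them to int64, which looks like a bit-pattern encoding expected by the Ascend grouped-matmul operator; that expectation is likewise inferred rather than documented in the diff.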