Files
enginex-mlu370-vllm/vllm-v0.6.2/tools/quant_tools/weight_only.py

153 lines
5.3 KiB
Python
Raw Normal View History

2026-02-04 17:22:39 +08:00
import argparse
import torch
from torch import Tensor
import numpy as np
import logging
from vllm import LLM
from utils_internal import convert_to_merged, cleanup, vllm_cleanup, should_skip
from dump_smooth import save_weights, save_generate_weights
logger = logging.getLogger(__name__)
def merge_adjacent_low_4bit(tensor: Tensor) -> Tensor:
    """Pack the low 4 bits of adjacent int8 elements into single int8 bytes.

    Along the last dimension, each (even, odd) pair of elements is merged
    into one byte: the odd element's low nibble becomes the high nibble and
    the even element's low nibble becomes the low nibble. The last dimension
    is therefore halved.

    Args:
        tensor: torch.int8 tensor whose last dimension has even length.

    Returns:
        torch.int8 tensor with the last dimension halved.

    Example:
        >>> a = torch.tensor([5, 7, 12, 3], dtype=torch.int8)
        >>> merge_adjacent_low_4bit(a).tolist()
        [117, 60]
    """
    assert tensor.dtype == torch.int8, "输入张量必须为int8类型"
    assert tensor.shape[-1] % 2 == 0, "输入张量最后一维长度需为偶数"
    # Work in int16 so `high << 4` cannot overflow before packing (e.g.
    # nibble 8 shifted left is 128, out of int8 range). The final cast back
    # to int8 truncates to the low byte, i.e. two's-complement wrap-around,
    # which yields the same packed bit pattern the previous numpy-based
    # implementation produced — but stays pure torch, so it also works for
    # tensors that are not convertible to numpy (e.g. device tensors).
    wide = tensor.to(torch.int16)
    low = wide[..., 0::2] & 0x0F
    high = wide[..., 1::2] & 0x0F
    return ((high << 4) | low).to(torch.int8)
def cal_weightonly_weight(weight, weight_bits, qmin, qmax, has_qzeros, eps: float = 1e-8):
    '''
    Symmetrically quantize a weight tensor per output row for weight-only inference.

    return: (quantized_weight, scales, qzeros) — qzeros is None when has_qzeros is False
    args:
        weight: float tensor to be quantized (2 or 3 dims)
        weight_bits: target quantization bit width (4-bit results are packed two per byte)
        qmin: minimum value in quantized range
        qmax: maximum value in quantized range
        has_qzeros: whether to also emit a zero-filled qzeros tensor
        eps: floor applied to the per-row magnitude so an all-zero row cannot
             produce a zero scale (and a divide-by-zero)
    '''
    assert weight.numel() != 0, "weight should not be empty tensor"
    assert weight.dim() in (2, 3), "Invalid dim. The dim of weight should be 2 or 3"
    assert weight.dtype in (torch.float32, torch.float16, torch.bfloat16), \
        "Invalid datatype. Weight must be torch.float32 or torch.float16 or torch.bfloat16"
    # One scale per row: the largest magnitude in the row maps onto qmax.
    row_magnitude = weight.float().abs().clamp(min=eps).max(dim=-1).values
    scale = row_magnitude / qmax
    # Divide by the row scale, round to nearest integer, clip into range.
    quantized = torch.round((weight / scale[..., None]).float())
    quantized = quantized.clip(min=qmin, max=qmax).to(torch.int8)
    scales_out = scale.squeeze()
    # 4-bit values are packed two-per-byte; 8-bit values are stored as-is.
    packed = merge_adjacent_low_4bit(quantized) if weight_bits == 4 else quantized
    zeros = torch.zeros_like(scales_out, dtype=torch.int32) if has_qzeros else None
    return packed, scales_out, zeros
def generate_weightonly_weight(act_range, name_parameters, args):
    '''
    Convert hugging face weights into quantized weight-only weights.

    args:
        act_range: non parallem act_range
        name_parameters: non parallel hugging face named parameters
        args: arguments from main
    '''
    bits = 8 if args.weight_only_precision == 'int8' else 4
    lo = float(-2**(bits - 1))
    hi = float(2**(bits - 1) - 1)
    with_zeros = args.has_qzeros
    result = {}
    for name, param in name_parameters.items():
        # Layers excluded for this model type are copied through unchanged.
        if should_skip(args.model_type, name):
            logger.info(f"skip {name}")
            result[name] = param
            continue
        # Bias vectors are never quantized.
        if name.endswith("bias"):
            result[name] = param
            continue
        layer_name = ".".join(name.split(".")[:-1])
        # Only layers with recorded activation ranges get quantized;
        # everything else passes through untouched.
        if layer_name not in act_range:
            result[name] = param
            continue
        qweight, scales, qzeros = cal_weightonly_weight(param, bits, lo, hi, with_zeros)
        result[f'{layer_name}.qweight'] = qweight
        result[f'{layer_name}.scales'] = scales.to(args.torch_scales_smooth_dtype)
        if with_zeros:
            result[f'{layer_name}.qzeros'] = qzeros
    return result
def generate_weights_of_weight_only(llm: LLM, args: argparse.Namespace):
    '''
    generate weightonly weights

    Drives the full pipeline: pull per-rank activation ranges and parameters
    from the vLLM workers, tear the engine down, merge the tensor-parallel
    shards, quantize, and save the results. Intermediate structures are
    deleted and cleanup() is invoked after each stage to keep peak memory down.

    args:
        llm: LLM instance
        args: argument from main (must provide tp_size among others)
    '''
    tp_size = args.tp_size
    # Install the smooth/activation hooks on every worker, then detach them.
    # NOTE(review): no forward pass happens between setup and removal here —
    # presumably calibration ran earlier or installing the hooks has the
    # needed side effect; confirm against the worker implementation.
    llm.llm_engine.model_executor._run_workers("setup_smooth_hook")
    llm.llm_engine.model_executor._run_workers("remove_hooks")
    # One entry per tensor-parallel rank for each of these.
    act_range = llm.llm_engine.model_executor._run_workers("get_act_range")
    named_parameters = llm.llm_engine.model_executor._run_workers("get_named_parameters")
    # Release the vLLM engine before the memory-heavy merge below.
    vllm_cleanup(llm)
    cleanup()
    logger.info("get act_range and named_parameters from llm finished")
    # Merge per-rank shards back into full (non-parallel) tensors.
    merged_act_range, merged_named_parameters, _ = convert_to_merged(act_range, named_parameters, tp_size, args)
    save_weights(merged_named_parameters, args)
    # Drop the per-rank copies as soon as they are no longer needed.
    del act_range
    del named_parameters
    cleanup()
    logger.info("get merged_act_range and merged_named_parameters finished")
    weightonly_weight = generate_weightonly_weight(merged_act_range, merged_named_parameters, args)
    save_generate_weights(weightonly_weight, args)
    del merged_act_range
    del merged_named_parameters
    cleanup()
    logger.info("get weightonly_weight finished")
    return weightonly_weight