Files
enginex-mlu370-vllm/vllm-v0.6.2/tools/quant_tools/weight_only.py
2026-02-04 17:22:39 +08:00

153 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import torch
from torch import Tensor
import numpy as np
import logging
from vllm import LLM
from utils_internal import convert_to_merged, cleanup, vllm_cleanup, should_skip
from dump_smooth import save_weights, save_generate_weights
logger = logging.getLogger(__name__)
def merge_adjacent_low_4bit(tensor: Tensor) -> Tensor:
    """Pack pairs of int4 values (stored one-per-int8) into single int8 bytes.

    For each adjacent pair along the last dimension, the low 4 bits of the
    even-index element go into the low nibble of the output byte and the low
    4 bits of the following odd-index element go into the high nibble, so the
    last dimension shrinks by half.

    Args:
        tensor: torch.int8 tensor whose last dimension has even length.

    Returns:
        torch.int8 tensor with the last dimension halved; each element is the
        nibble-merge of one adjacent input pair.

    Example:
        a = torch.tensor([5, 7, 12, 3], dtype=torch.int8)
        merge_adjacent_low_4bit(a)  # -> tensor([117, 60], dtype=torch.int8)
        # 117 == (7 << 4) | 5, 60 == (3 << 4) | 12

    Note: the previous implementation routed through NumPy ufuncs, which
    converts the tensor to an ndarray (returning np.ndarray, not Tensor) and
    breaks for tensors on non-CPU devices. Pure torch bitwise ops below have
    identical int8 wraparound semantics.
    """
    # Validate input contract: int8 storage and pairable last dimension.
    assert tensor.dtype == torch.int8, "输入张量必须为int8类型"
    assert tensor.shape[-1] % 2 == 0, "输入张量最后一维长度需为偶数"
    # Keep only the low nibble of each element before merging.
    even = tensor[..., 0::2] & 0x0F
    odd = tensor[..., 1::2] & 0x0F
    # High nibble from the odd element, low nibble from the even one.
    # int8 left-shift wraps modulo 256, matching the old NumPy behavior.
    return (odd << 4) | even
def cal_weightonly_weight(weight, weight_bits, qmin, qmax, has_qzeros, eps: float = 1e-8):
    """Quantize a floating-point weight tensor per channel for weight-only inference.

    The scale for each channel is max(|w|) / qmax along the last dimension
    (clamped below by ``eps`` to avoid division by zero); weights are then
    rounded, clipped into [qmin, qmax] and stored as int8. For 4-bit mode the
    int8 values are nibble-packed pairwise via ``merge_adjacent_low_4bit``.

    Args:
        weight: 2-D or 3-D float32/float16/bfloat16 tensor to quantize.
        weight_bits: target bit width (4 or 8).
        qmin: minimum value of the quantized range.
        qmax: maximum value of the quantized range.
        has_qzeros: when True, also emit an all-zero int32 qzeros tensor.
        eps: floor applied to |weight| before taking the max, guarding
            against all-zero channels.

    Returns:
        (quantized_weight, scales, qzeros) where qzeros is None when
        ``has_qzeros`` is False.
    """
    assert weight.numel() != 0, "weight should not be empty tensor"
    assert weight.dim() in (2, 3), "Invalid dim. The dim of weight should be 2 or 3"
    assert weight.dtype in (torch.float32, torch.float16, torch.bfloat16
                            ), "Invalid datatype. Weight must be torch.float32 or torch.float16 or torch.bfloat16"
    # Per-channel scale: largest magnitude along the last dim, mapped to qmax.
    abs_w = weight.float().abs().clamp(min=eps)
    per_channel_scale = abs_w.max(dim=-1).values / qmax
    # Round to the nearest integer level and clip into the quantized range.
    scaled = (weight / per_channel_scale[..., None]).float()
    int_w = torch.round(scaled).clip(min=qmin, max=qmax).to(torch.int8)
    squeezed_scale = per_channel_scale.squeeze()
    # 4-bit mode packs two values per byte; 8-bit keeps one value per byte.
    if weight_bits == 4:
        packed = merge_adjacent_low_4bit(int_w)
    else:
        packed = int_w
    qzeros = torch.zeros_like(squeezed_scale, dtype=torch.int32) if has_qzeros else None
    return packed, squeezed_scale, qzeros
def generate_weightonly_weight(act_range, name_parameters, args):
    """Build a weight-only quantized state dict from HF named parameters.

    Parameters whose layer appears in ``act_range`` are quantized via
    ``cal_weightonly_weight`` and stored under ``<layer>.qweight`` /
    ``<layer>.scales`` (and ``<layer>.qzeros`` when requested); skipped
    layers, biases and unmatched parameters pass through unchanged.

    Args:
        act_range: non-parallel activation ranges keyed by layer name.
        name_parameters: non-parallel Hugging Face named parameters.
        args: parsed CLI arguments (uses has_qzeros, weight_only_precision,
            model_type, torch_scales_smooth_dtype).

    Returns:
        dict mapping parameter names to (possibly quantized) tensors.
    """
    quantized = {}
    has_qzeros = args.has_qzeros
    bits = 8 if args.weight_only_precision == 'int8' else 4
    # Symmetric signed range for the chosen bit width.
    low = float(-2 ** (bits - 1))
    high = float(2 ** (bits - 1) - 1)
    for name, param in name_parameters.items():
        # Model-specific exclusions pass through untouched.
        if should_skip(args.model_type, name):
            logger.info(f"skip {name}")
            quantized[name] = param
            continue
        # Biases are never quantized.
        if name.endswith("bias"):
            quantized[name] = param
            continue
        layer_name = ".".join(name.split(".")[:-1])
        # Only layers with calibrated activation ranges get quantized.
        if layer_name not in act_range:
            quantized[name] = param
            continue
        qweight, scales, qzeros = cal_weightonly_weight(param, bits, low, high, has_qzeros)
        quantized[f'{layer_name}.qweight'] = qweight
        quantized[f'{layer_name}.scales'] = scales.to(args.torch_scales_smooth_dtype)
        if has_qzeros:
            quantized[f'{layer_name}.qzeros'] = qzeros
    return quantized
def generate_weights_of_weight_only(llm: LLM, args: argparse.Namespace):
    """Run the end-to-end weight-only quantization pipeline against an LLM.

    Collects activation ranges and named parameters from the vLLM workers,
    merges the per-rank shards, saves the merged weights, quantizes them and
    saves the quantized result. Intermediate state is deleted eagerly (with
    ``cleanup()``) to bound peak memory.

    Args:
        llm: LLM instance whose workers expose the smoothing hooks.
        args: parsed CLI arguments (uses tp_size plus whatever the helpers
            consume).

    Returns:
        dict of quantized weight-only tensors.
    """
    run_workers = llm.llm_engine.model_executor._run_workers
    run_workers("setup_smooth_hook")
    run_workers("remove_hooks")
    raw_act_range = run_workers("get_act_range")
    raw_named_params = run_workers("get_named_parameters")
    # Tear the engine down before merging — the per-rank data is all we need.
    vllm_cleanup(llm)
    cleanup()
    logger.info("get act_range and named_parameters from llm finished")
    merged_act_range, merged_named_params, _ = convert_to_merged(
        raw_act_range, raw_named_params, args.tp_size, args)
    save_weights(merged_named_params, args)
    del raw_act_range
    del raw_named_params
    cleanup()
    logger.info("get merged_act_range and merged_named_parameters finished")
    quantized = generate_weightonly_weight(merged_act_range, merged_named_params, args)
    save_generate_weights(quantized, args)
    del merged_act_range
    del merged_named_params
    cleanup()
    logger.info("get weightonly_weight finished")
    return quantized