import argparse
import logging

import numpy as np
import torch
from torch import Tensor
from vllm import LLM

from dump_smooth import save_generate_weights, save_weights
from utils_internal import cleanup, convert_to_merged, should_skip, vllm_cleanup

logger = logging.getLogger(__name__)


def merge_adjacent_low_4bit(tensor: Tensor) -> Tensor:
    """Pack the low 4 bits of adjacent int8 elements into single int8 values.

    Each output element carries the low nibble of the even-indexed input in
    its low 4 bits and the low nibble of the following odd-indexed input in
    its high 4 bits, halving the last dimension.

    Args:
        tensor: int8 tensor whose last dimension has even length.

    Returns:
        int8 tensor with the last dimension halved; each element is the
        packed pair of adjacent low nibbles.

    Example:
        >>> a = torch.tensor([5, 7, 12, 3], dtype=torch.int8)
        >>> merge_adjacent_low_4bit(a)
        tensor([117,  60], dtype=torch.int8)
    """
    assert tensor.dtype == torch.int8, "input tensor must be int8"
    assert tensor.shape[-1] % 2 == 0, "last dim of input tensor must be even"
    # Pure torch ops so the result stays a torch.Tensor. The previous
    # numpy-based implementation silently returned an ndarray, contradicting
    # the documented contract; bit patterns are unchanged (int8 left-shift
    # wraps into the sign bit identically in numpy and torch).
    low = tensor[..., 0::2] & 0x0F
    high = tensor[..., 1::2] & 0x0F
    return (high << 4) | low


def cal_weightonly_weight(weight, weight_bits, qmin, qmax, has_qzeros,
                          eps: float = 1e-8):
    """Quantize a float weight tensor for weight-only inference.

    Uses per-row symmetric quantization: the scale maps the max |w| along
    the last dimension to qmax. For 4-bit, adjacent quantized values are
    packed two-per-byte via merge_adjacent_low_4bit.

    Args:
        weight: float tensor to quantize (dim must be 2 or 3).
        weight_bits: quantized bit width (4 or 8).
        qmin: minimum value of the quantized range.
        qmax: maximum value of the quantized range.
        has_qzeros: whether to also produce a (zero-filled) qzeros tensor.
        eps: lower clamp for |w| so the scale never divides by zero.

    Returns:
        Tuple (quantized_weight, scales, qzeros); qzeros is None when
        has_qzeros is False.
    """
    assert weight.numel() != 0, "weight should not be empty tensor"
    assert weight.dim() in (2, 3), \
        "Invalid dim. The dim of weight should be 2 or 3"
    assert weight.dtype in [torch.float32, torch.float16, torch.bfloat16], \
        "Invalid datatype. Weight must be torch.float32 or torch.float16 or torch.bfloat16"
    # Per-row symmetric scale; clamp keeps all-zero rows from producing a
    # zero scale (and a NaN/inf quantized weight).
    weight_scale = weight.float().abs().clamp(min=eps).max(dim=-1).values / qmax
    unpacked_weight = (torch.round((weight / weight_scale[..., None]).float())
                       .clip(min=qmin, max=qmax)
                       .to(torch.int8))
    scale_quant_orig_c = weight_scale.squeeze()
    if weight_bits == 4:
        quantized_weight = merge_adjacent_low_4bit(unpacked_weight)
    else:
        quantized_weight = unpacked_weight
    if has_qzeros:
        # Symmetric quantization: zero points are all zero.
        qzeros = torch.zeros_like(scale_quant_orig_c, dtype=torch.int32)
    else:
        qzeros = None
    return quantized_weight, scale_quant_orig_c, qzeros


def generate_weightonly_weight(act_range, name_parameters, args):
    """Convert Hugging Face weights into quantized weight-only weights.

    Layers present in act_range are quantized into '<layer>.qweight',
    '<layer>.scales' (and optionally '<layer>.qzeros'); skipped layers and
    biases are passed through unchanged.

    Args:
        act_range: non-parallel activation-range mapping keyed by layer name.
        name_parameters: non-parallel Hugging Face named parameters.
        args: arguments from main (weight_only_precision, has_qzeros,
            model_type, torch_scales_smooth_dtype).

    Returns:
        Dict of output parameter name to tensor.
    """
    weightonly_weight = {}
    has_qzeros = args.has_qzeros
    weight_bits = 8 if args.weight_only_precision == 'int8' else 4
    qmin = float(-2 ** (weight_bits - 1))
    qmax = float(2 ** (weight_bits - 1) - 1)
    for name, param in name_parameters.items():
        if should_skip(args.model_type, name):
            logger.info(f"skip {name}")
            weightonly_weight[name] = param
            continue
        if name.endswith("bias"):
            weightonly_weight[name] = param
            continue
        name_parts = name.split(".")
        layer_name = ".".join(name_parts[:-1])
        if layer_name in act_range:
            qweight, scales, qzeros = cal_weightonly_weight(
                param, weight_bits, qmin, qmax, has_qzeros)
            scales = scales.to(args.torch_scales_smooth_dtype)
            weightonly_weight[f'{layer_name}.qweight'] = qweight
            weightonly_weight[f'{layer_name}.scales'] = scales
            if has_qzeros:
                weightonly_weight[f'{layer_name}.qzeros'] = qzeros
        else:
            weightonly_weight[name] = param
    return weightonly_weight


def generate_weights_of_weight_only(llm: LLM, args: argparse.Namespace):
    """Generate weight-only quantized weights from a running LLM instance.

    Gathers activation ranges and named parameters from the vLLM workers,
    merges tensor-parallel shards, then quantizes and saves the result.
    Intermediate objects are deleted and cleanup() called between stages to
    keep peak memory down.

    Args:
        llm: LLM instance.
        args: arguments from main.

    Returns:
        Dict of quantized weight-only weights.
    """
    tp_size = args.tp_size
    llm.llm_engine.model_executor._run_workers("setup_smooth_hook")
    # NOTE(review): hooks are removed immediately after being set up with no
    # forward pass in between — confirm whether a calibration run is expected
    # here or performed elsewhere.
    llm.llm_engine.model_executor._run_workers("remove_hooks")
    act_range = llm.llm_engine.model_executor._run_workers("get_act_range")
    named_parameters = llm.llm_engine.model_executor._run_workers(
        "get_named_parameters")
    vllm_cleanup(llm)
    cleanup()
    logger.info("get act_range and named_parameters from llm finished")
    merged_act_range, merged_named_parameters, _ = convert_to_merged(
        act_range, named_parameters, tp_size, args)
    save_weights(merged_named_parameters, args)
    del act_range
    del named_parameters
    cleanup()
    logger.info("get merged_act_range and merged_named_parameters finished")
    weightonly_weight = generate_weightonly_weight(
        merged_act_range, merged_named_parameters, args)
    save_generate_weights(weightonly_weight, args)
    del merged_act_range
    del merged_named_parameters
    cleanup()
    logger.info("get weightonly_weight finished")
    return weightonly_weight