# Snapshot metadata preserved from the paste viewer (not part of the program):
# captured 2026-02-04 17:22:39 +08:00 — 419 lines, 15 KiB, Python
import argparse
import torch
from datasets import load_dataset
import logging
import csv
import os
from vllm import LLM, SamplingParams
from utils_internal import convert_to_merged, cleanup, vllm_cleanup, should_skip
from input_context import prepare_inputs
from dump_smooth import save_prompt_token_ids, save_input_ids, save_act_range, save_weights, save_generate_weights
from model_special import smooth_model_config
logger = logging.getLogger(__name__)
def load_prompts_from_csv(args):
    '''
    Load prompts from a CSV file.

    The prompt text is taken from the first column of each row. When
    ``args.prompt_file`` is None, falls back to the bundled
    ``summarize_1024_prompts.csv`` that sits next to this module.

    args:
        args: parsed CLI namespace; uses ``prompt_file`` and ``num_samples``
    returns:
        list of at most ``args.num_samples`` prompt strings (possibly empty)
    '''
    if args.prompt_file is not None:
        prompt_file = args.prompt_file
    else:
        current_dir = os.path.dirname(__file__)
        prompt_file = os.path.join(current_dir, 'summarize_1024_prompts.csv')
    # Read the first column directly; skipping empty rows means a trailing
    # blank line or an empty file no longer raises IndexError (the previous
    # list(zip(*reader))[0] did). Encoding is pinned to utf-8 so the result
    # does not depend on the platform default.
    with open(prompt_file, 'r', newline='', encoding='utf-8') as file:
        reader = csv.reader(file)
        loaded_prompts = [row[0] for row in reader if row]
    num_samples = min(args.num_samples, len(loaded_prompts))
    return loaded_prompts[:num_samples]
def save_summarize_1024_prompts_as_csv(prompts):
    '''
    Save prompts to ``summarize_1024_prompts.csv`` in the current working
    directory, one prompt per row in a single column.

    (The previous docstring said "512"; the file written has always been the
    1024-prompt cache consumed by load_prompts_from_csv.)

    args:
        prompts: iterable of prompt strings
    '''
    with open('summarize_1024_prompts.csv', 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        # One single-column row per prompt; equivalent to the old
        # zip(*[prompts]) transpose but without the intermediate list.
        writer.writerows([prompt] for prompt in prompts)
def generate_prompts(args: argparse.Namespace):
    '''
    Build the calibration prompt list (and optional token-id list) for the
    selected evaluation task.

    Known tasks resolve to a built-in dataset configuration; any other task
    name is treated as "custom" and requires dataset_name / dataset_input_key
    / dataset_split on ``args``. For the "summarize" task with at most 1024
    samples (or whenever ``args.prompt_file`` is given) prompts come from a
    cached CSV instead of the hub dataset.

    args:
        args: parsed CLI namespace
    returns:
        (prompts, prompt_token_ids); either element is None when empty
    '''
    builtin_tasks = {
        "code_completion": {
            "dataset_name": "openai_humaneval",
            "dataset_revision": None,
            "dataset_input_key": "prompt",
            "dataset_split": "test"
        },
        "summarize": {
            "dataset_name": "ccdv/cnn_dailymail",
            "dataset_revision": "3.0.0",
            "dataset_input_key": "article",
            "dataset_split": "train"
        },
        "summarize_long": {
            "dataset_name": "tau/zero_scrolls",
            "dataset_revision": "squality",
            "dataset_input_key": "input",
            "dataset_split": "validation"
        },
        "summarize_hg": {
            "dataset_name": "cnn_dailymail",
            "dataset_revision": "3.0.0",
            "dataset_input_key": "article",
            "dataset_split": "validation"
        },
        "text_generation": {
            "dataset_name": "lambada",
            "dataset_revision": None,
            "dataset_input_key": "text",
            "dataset_split": "validation"
        }
    }
    task_cfg = builtin_tasks.get(args.eval_task)
    if task_cfg is None:
        # Custom task: every dataset field must be supplied explicitly.
        assert args.dataset_name is not None, f"dataset_name is None when eval_task == custom"
        assert args.dataset_input_key is not None, f"dataset_input_key is None when eval_task == custom"
        assert args.dataset_split is not None, f"dataset_split is None when eval_task == custom"
        task_cfg = {
            "dataset_name": args.dataset_name,
            "dataset_revision": args.dataset_revision,
            "dataset_input_key": args.dataset_input_key,
            "dataset_split": args.dataset_split
        }
    use_cached_csv = args.prompt_file is not None or (args.eval_task == "summarize" and args.num_samples <= 1024)
    if use_cached_csv:
        prompts = load_prompts_from_csv(args)
        num_samples = min(args.num_samples, len(prompts))
    else:
        dataset = load_dataset(task_cfg["dataset_name"],
                               task_cfg["dataset_revision"],
                               cache_dir=args.dataset_cache_dir,
                               split=task_cfg["dataset_split"],
                               trust_remote_code=True)
        num_samples = min(args.num_samples, len(dataset))
        prompts = dataset[0:num_samples][task_cfg["dataset_input_key"]]
    prompt_token_ids = []
    if args.has_prompt_token_id:
        batch_input_ids = prepare_inputs(prompts,
                                         args.tokenizer,
                                         args.model_name,
                                         args.model_version,
                                         args.max_input_length,
                                         eval_task=args.eval_task,
                                         add_special_tokens=args.add_special_tokens)
        save_prompt_token_ids(batch_input_ids, args)
        prompt_token_ids = [batch_input_ids[i].tolist() for i in range(num_samples)]
    # Normalize empties to None so downstream llm.generate picks exactly one
    # of the two input forms; non-empty prompts are truncated by characters.
    prompts = [s[:args.max_input_length] for s in prompts] if prompts else None
    if not prompt_token_ids:
        prompt_token_ids = None
    return prompts, prompt_token_ids
def _concat_layer_weights(name, name_parameters, sublayer_names):
    '''Concatenate the sibling projection weights of the layer that owns *name* along dim 0.'''
    # name looks like "<layer>.<proj>.weight"; strip the last two parts to
    # get the owning attention/MLP layer prefix.
    layer_prefix = ".".join(name.split(".")[:-2])
    weights = [name_parameters[f"{layer_prefix}.{sub}.weight"] for sub in sublayer_names]
    return torch.cat(weights, dim=0)


@torch.no_grad()
def get_smooth_cal_weight(name, weight, name_parameters, act_range, model_type):
    '''
    Get cal_weight for the smooth process, handling layers that vLLM runs as
    merged q/k/v or gate/up projections: the smoother must be derived from
    the concatenation of all sibling weights, not from one shard.

    args:
        name: weight name
        weight: weight value
        name_parameters: named parameters
        act_range: layer act range info of name (uses "is_qkv"/"is_merge" flags)
        model_type: model type (key into smooth_model_config)
    returns:
        the concatenated sibling weights for merged layers, else *weight* itself
    '''
    # The two merged cases previously duplicated the split/lookup/concat
    # logic inline; both now go through _concat_layer_weights.
    if act_range["is_qkv"] is True:
        return _concat_layer_weights(name, name_parameters,
                                     smooth_model_config[model_type]["qkv_list"])
    if act_range["is_merge"] is True:
        return _concat_layer_weights(name, name_parameters,
                                     smooth_model_config[model_type]["gate_up_list"])
    return weight
@torch.no_grad()
def cal_smoother(weight, act_range_x, alpha=0.5):
    '''
    Calculate the per-input-channel smoother value.

    args:
        weight: smoother weight
        act_range_x: activation max value per channel
        alpha: smooth factor, default 0.5
    returns:
        float64 tensor: act_max^alpha / weight_max^(1-alpha), clamped to >= 1e-6
    '''
    assert weight.shape[-1] == act_range_x.numel()
    # Max |w| over every output row, per input channel, in float64.
    flat = weight.view(-1, weight.shape[-1])
    w_max = flat.abs().amax(dim=0).to(float).clamp(min=1e-6)
    act_max = act_range_x.to(w_max.device).to(float)
    return (act_max.pow(alpha) / w_max.pow(1 - alpha)).clamp(min=1e-6)
@torch.no_grad()
def cal_qweight_scales(sweight, smooth_act_range_x, per_token, per_channel):
    '''
    Quantize the smoothed weight to int8 and derive the matching scales.

    args:
        sweight: weight which has been divided by the smoother value
        smooth_act_range_x: activation max value which has been divided by the smoother value
        per_token: bool, whether the activation scale is applied dynamically at runtime
        per_channel: bool, whether the weight scale is kept per output channel
    returns:
        (qweight int8 tensor, per_channel_scale, scale_to_int,
         sinfo = [max_scale_w, max_scale_x, max_scale_w/max_scale_x])
    '''
    # Per-tensor activation scale and per-row / per-tensor weight scales.
    x_scale_t = smooth_act_range_x.max() / 127.0
    w_row_max = sweight.abs().max(dim=-1)[0].to(float).clamp(min=1e-6)
    w_scale_c = w_row_max / 127.0
    w_scale_t = w_row_max.max() / 127
    divisor = w_scale_c[..., None] if per_channel else w_scale_t
    qweight = (sweight / divisor).clip(-128, 127).to(torch.int8)
    scale_to_int = 1 / x_scale_t
    if per_token:
        # Dynamic activation quantization: only the weight scale is exported.
        per_channel_scale = w_scale_c if per_channel else w_scale_t
    elif per_channel:
        # Static activations: fold the activation scale into the weight scale
        # and broadcast scale_to_int across the hidden dimension.
        per_channel_scale = x_scale_t * w_scale_c
        scale_to_int = scale_to_int.repeat(smooth_act_range_x.numel())
    else:
        per_channel_scale = x_scale_t * w_scale_t
    # Make sure both scales come back as at-least-1-D tensors.
    per_channel_scale = per_channel_scale.squeeze()
    if per_channel_scale.numel() == 1 and per_channel_scale.dim() == 0:
        per_channel_scale = per_channel_scale.unsqueeze(0)
    if scale_to_int.numel() == 1 and scale_to_int.dim() == 0:
        scale_to_int = scale_to_int.unsqueeze(0)
    sinfo = [
        w_scale_t.item(), x_scale_t.item(),
        w_scale_t.item() / x_scale_t.item()
    ]
    return qweight, per_channel_scale, scale_to_int, sinfo
def check_smooth_weight_vaild(name, qweight, per_channel_scale, smooth, qzeros, scale_to_int):
    '''
    Log an error for every output tensor of the smooth process that contains
    inf or nan. qzeros may be None (skipped). Diagnostic only — returns None.
    '''
    candidates = [
        ("qweight", qweight),
        ("per_channel_scale", per_channel_scale),
        ("smooth", smooth),
        ("scale_to_int", scale_to_int),
    ]
    if qzeros is not None:
        candidates.append(("qzeros", qzeros))
    for label, tensor in candidates:
        if torch.isinf(tensor).any() or torch.isnan(tensor).any():
            logger.error(f"name:{name} {label} has inf or nan")
@torch.no_grad()
def cal_smooth_weight(name, act_range_x, weight, smooth_value, has_qzeros, per_token, per_channel, cal_weight):
    '''
    Calculate qweight, scales, smooth and qzeros for one layer.

    args:
        name: weight name (diagnostics only)
        act_range_x: activation max value per input channel
        weight: weight to be quantized
        smooth_value: smooth factor alpha forwarded to cal_smoother
        has_qzeros: whether to also generate a (zero) qzeros tensor
        per_token: bool, whether the activation scale is applied dynamically
        per_channel: bool, whether the weight scale is kept per channel
        cal_weight: weight used to derive the smoother; for merged q/k/v or
            gate/up layers this is the concatenation of the sibling weights
            (see get_smooth_cal_weight), otherwise it equals *weight*
    returns:
        (qweight, per_channel_scale, smooth, qzeros, scale_to_int, sinfo)
    '''
    smoother = cal_smoother(cal_weight, act_range_x, smooth_value)
    smooth_act_range_x = act_range_x / smoother
    sweight = weight * (smoother.view(1, -1))
    qweight, per_channel_scale, scale_to_int, sinfo = cal_qweight_scales(sweight, smooth_act_range_x, per_token,
                                                                         per_channel)
    qweight = qweight.reshape(weight.shape)
    # The runtime multiplies activations by 1/smoother, so export the inverse.
    smooth = (1 / smoother).squeeze()
    if has_qzeros:
        # Symmetric quantization: zero points are all zero.
        qzeros = torch.zeros_like(per_channel_scale, dtype=torch.int32)
    else:
        qzeros = None
    # Debug aid, disabled by default; uncomment to validate for inf/nan.
    # check_smooth_weight_vaild(name, qweight, per_channel_scale, smooth, qzeros, scale_to_int)
    return qweight, per_channel_scale, smooth, qzeros, scale_to_int, sinfo
@torch.no_grad()
def generate_smooth_weight(act_range, name_parameters, args):
    '''
    Generate the smooth-quantized state dict from collected activation ranges.

    args:
        act_range: act_range collected in model running
        name_parameters: hugging face model named parameters
        args: argument from main
    returns:
        (smooth_weight, smooth_info) — quantized/passthrough tensors keyed by
        name, plus per-layer scale statistics
    '''
    smooth_weight = {}
    smooth_info = {"title": ["max_scale_w, max_scale_x, max_scale_w/max_scale_x"]}
    for name, param in name_parameters.items():
        # Explicitly excluded layers and biases pass through unquantized.
        if should_skip(args.model_type, name):
            logger.info(f"skip {name}")
            smooth_weight[name] = param
            continue
        if name.endswith("bias"):
            smooth_weight[name] = param
            continue
        layer_name = ".".join(name.split(".")[:-1])
        if layer_name not in act_range:
            # No calibration data for this layer — keep the original tensor.
            smooth_weight[name] = param
            continue
        act_range_x = act_range[layer_name]['x']
        cal_weight = get_smooth_cal_weight(name, param, name_parameters, act_range[layer_name], args.model_type)
        qweight, per_channel_scale, smooth, qzeros, scale_to_int, sinfo = cal_smooth_weight(
            name, act_range_x, param, args.smooth_value, args.has_qzeros, args.per_token, args.per_channel,
            cal_weight)
        smooth_weight[f'{layer_name}.qweight'] = qweight
        smooth_weight[f'{layer_name}.per_channel_scale'] = per_channel_scale.to(args.torch_scales_smooth_dtype)
        if args.per_token is True:
            smooth_weight[f'{layer_name}.smooth'] = smooth.to(args.torch_scales_smooth_dtype)
        else:
            # Static mode folds the smoother into the activation scale.
            smooth_weight[f'{layer_name}.scale_to_int'] = (scale_to_int.to(args.torch_scales_smooth_dtype) *
                                                           smooth.to(args.torch_scales_smooth_dtype))
        if args.has_qzeros:
            smooth_weight[f'{layer_name}.qzeros'] = qzeros
        smooth_info[name] = sinfo
    return smooth_weight, smooth_info
def generate_weights_of_smoothquant(llm: LLM, args: argparse.Namespace):
    '''
    Generate smoothquant weights end-to-end: run calibration prompts through
    the vLLM engine with activation-range hooks installed, collect ranges and
    parameters from all workers, merge them across tensor-parallel ranks, then
    compute and save the smoothed/quantized weights.

    args:
        llm: LLM instance
        args: argument from main
    returns:
        (smooth_weight, smooth_info) as produced by generate_smooth_weight
    '''
    prompts, prompt_token_ids = generate_prompts(args)
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=args.output_len,
                                     repetition_penalty=args.repetition_penalty,
                                     temperature=args.temperature,
                                     top_p=args.top_p,
                                     top_k=args.top_k)
    tp_size = args.tp_size
    # Install activation-range hooks on every worker before generating.
    # NOTE(review): relies on vLLM-internal model_executor._run_workers —
    # may break across vLLM versions; confirm against the pinned release.
    llm.llm_engine.model_executor._run_workers("setup_smooth_hook", args.dump_input_ids)
    llm.generate(prompts, sampling_params, prompt_token_ids=prompt_token_ids, use_tqdm=True)
    logger.info("llm generate finished")
    llm.llm_engine.model_executor._run_workers("remove_hooks")
    # Pull the per-worker calibration results back to the driver process.
    act_range = llm.llm_engine.model_executor._run_workers("get_act_range")
    named_parameters = llm.llm_engine.model_executor._run_workers("get_named_parameters")
    # Tear down the engine and drop inputs before merging: the merged copies
    # below can be large, so free memory aggressively between stages.
    vllm_cleanup(llm)
    del prompts
    del prompt_token_ids
    cleanup()
    logger.info("get act_range and named_parameters from llm finished")
    # Merge the per-tensor-parallel-rank shards into full tensors.
    merged_act_range, merged_named_parameters, input_id_list = convert_to_merged(act_range, named_parameters, tp_size,
                                                                                args)
    save_input_ids(input_id_list, args)
    save_act_range(merged_act_range, args)
    save_weights(merged_named_parameters, args)
    del act_range
    del named_parameters
    cleanup()
    logger.info("get merged_act_range and merged_named_parameters finished")
    smooth_weight, smooth_info = generate_smooth_weight(merged_act_range, merged_named_parameters, args)
    save_generate_weights(smooth_weight, args)
    del merged_act_range
    del merged_named_parameters
    cleanup()
    logger.info("get smooth_weight finished")
    return smooth_weight, smooth_info