Files
2026-04-02 04:55:00 +00:00

1457 lines
72 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights, with VACC modifications."""
from typing import Iterable, Optional, Set, Tuple, Union, List, Any, Dict
import math
import os
import torch
import numpy as np
from torch import nn
from transformers import Qwen2Config
from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, get_tp_group)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.forward_context import ForwardContext, get_forward_context
# Import original classes to be modified
from vllm.model_executor.models.qwen2 import (Qwen2Attention as Qwen2AttentionOrig,
Qwen2DecoderLayer as Qwen2DecoderLayerOrig,
Qwen2MLP as Qwen2MLPOrig)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm_vacc.vllm.model_executor.layers.layernorm import RMSNorm_forward_vacc
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
from ..ops.mrope_op import get_sin_cos_mrope
import copy
logger = init_logger(__name__)
# --- VACC specific additions ---
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT, USE_FUSED_QWEN_ATTENTION
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size
def save_tensor_pth(name, tensor):
    """Debug helper: dump *tensor* under ./dump_qwen_pth/<hour>/process_<pid>/.

    Currently disabled via the unconditional ``return`` below; remove that
    line to re-enable dumping. Tensors are moved to CPU before saving;
    ``sin_cache``/``cos_cache`` entries are treated as lists of tensors.
    """
    return  # dumping disabled on purpose; delete this line to activate
    import time
    if tensor is None:
        return
    process_id = os.getpid()
    timestamp = time.strftime("%Y-%m-%d_%H", time.localtime())
    pathdir = f'./dump_qwen_pth/{timestamp}/process_{process_id}'
    if not os.path.exists(pathdir):
        os.makedirs(pathdir)
    # Find the first unused "<name>_<counter>.pth" file name.
    # (Fixed: the original recomputed the name with a stale counter before
    # incrementing, wasting one loop iteration per existing file.)
    counter = 0
    filename = f"{name}_{counter}.pth"
    while os.path.exists(os.path.join(pathdir, filename)):
        counter += 1
        filename = f"{name}_{counter}.pth"
    file_path = os.path.join(pathdir, filename)
    if isinstance(tensor, torch.Tensor):
        print(f'save (unknown) {tensor.shape}')
        torch.save(tensor.cpu(), file_path)
    elif 'sin_cache' in name or 'cos_cache' in name:
        tensor_list = [i.cpu() for i in tensor]
        torch.save(tensor_list, file_path)
    else:
        torch.save(tensor, file_path)
    return
# --- GPTQ Dequantization Logic (from gptq.py) ---
def int32_to_int4(s0, axis = -2):
    """Unpack GPTQ-packed int32 values into signed int4 values.

    Each int32 packs 8 nibbles (4-bit values). Nibble i occupies bits
    [4*i, 4*i+3] of the int32; memory is little-endian. The nibbles are NOT
    two's complement: the real value is ``nibble - 8`` (range -8..7), e.g.
    0b1111 = 15 -> 7, 0b0110 = 6 -> -2.

    Worked example: bytes 0x2B 0x37 0xCB 0x6A in memory form the int32
    0x6ACB372B; nibble-by-nibble (low to high) that is
    11,2,7,3,11,12,10,6 -> minus 8 -> 3,-6,-1,-5,3,4,2,-2.

    Args:
        s0: packed tensor (int32).
        axis: -2/0 unpacks along the first dimension (8,K//8,N -> K,N);
              anything else unpacks along the last (8,N,K//8 -> N,K).
    Returns:
        int32 tensor of unpacked int4 values in [-8, 7].
    """
    s = s0.view(torch.uint32)
    all = []
    for i in range(8):
        # Mask for nibble i, then shift it down to the low 4 bits.
        x = 15 << (i*4)
        # s2 = torch.bitwise_and(x,s)  # torch path kept for reference; numpy used instead
        s2 = torch.from_numpy(np.bitwise_and(x, s.numpy()))
        s3 = s2 / (2 ** (i*4))
        s4 = s3.to(torch.int32)
        # Two's-complement interpretation gives wrong results here:
        # s4[s4 > 7] = s4[s4 > 7]-16
        # Subtracting 8 is correct; resulting range is -8..7.
        s4 = s4 - 8
        all.append(s4.reshape(1,*s4.shape))
    all = torch.concatenate(all, 0)
    if axis == -2 or axis == 0:
        # 8,K//8,N => K//8,8,N => K,N  (interleave nibbles back into rows)
        all = all.transpose(-2,0).reshape(-1,all.shape[-1]).contiguous()
    else:
        # 8,N,K//8 => N,K//8,8 => N,K  (interleave nibbles back into columns)
        all = all.permute(1,2,0).reshape(all.shape[-2],-1).contiguous()
    return all
def dequant_weight(qw, scales, group_size = 128):
    """Dequantize packed GPTQ int4 weights to fp16: unpack, then apply the
    per-group scales (one scale shared by every ``group_size`` rows)."""
    if TRANSPOSE_GPTQ_WEIGHT:
        out_features = qw.shape[0]
        unpack_axis = -1
    else:
        out_features = qw.shape[1]
        unpack_axis = -2
    # int32 -> 8x int4 -> fp16
    w_fp16 = int32_to_int4(qw, unpack_axis).to(torch.float16)
    if TRANSPOSE_GPTQ_WEIGHT:
        scales = scales.T.contiguous()
        w_fp16 = w_fp16.T.contiguous()
    # Replicate each scale row group_size times so it lines up with the
    # unpacked weight rows that share it.
    expanded_scales = torch.concatenate([scales] * group_size, 1).reshape(-1, out_features)
    return w_fp16 * expanded_scales
def apply_gptq_linear(
    input_tensor: torch.Tensor,
    layer: Union[QKVParallelLinear, RowParallelLinear],
) -> torch.Tensor:
    """
    Applies a GPTQ-quantized linear layer by dequantizing weights on-the-fly.
    """
    out_axis = -2 if TRANSPOSE_GPTQ_WEIGHT else -1
    out_shape = input_tensor.shape[:-1] + (layer.qweight.shape[out_axis], )  # M,N
    x2d = input_tensor.reshape(-1, input_tensor.shape[-1])
    # The attached linear_method is assumed to be a GPTQ derivative.
    quant_method = layer.quant_method
    # scale_k is a VACC-specific parameter for scale broadcasting (default 1).
    effective_group = quant_method.quant_config.group_size // getattr(quant_method, "scale_k", 1)
    # Dequantize on CPU, then move the fp16 weight back to the device.
    weight = dequant_weight(layer.qweight.cpu(), layer.scales.cpu(), effective_group).to(layer.qweight.device)
    result = torch.matmul(x2d, weight)
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        result = result + bias
    return result.reshape(out_shape)
def get_gptq_group_size(layer: Union[QKVParallelLinear, RowParallelLinear]):
    """Effective GPTQ group size for *layer*: the configured group size
    divided by the VACC-specific scale_k broadcast factor (default 1)."""
    method = layer.quant_method  # assumed to be a GPTQ derivative
    divisor = getattr(method, "scale_k", 1)
    return method.quant_config.group_size // divisor
def apply_gptq_linear_(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    qzeros: Optional[torch.Tensor],
    group_size_h: int,
    group_size_w: int,
) -> torch.Tensor:
    """
    GPTQ linear via the fused VACC w4a8 block-int4 matmul kernel.

    ``qzeros`` is accepted for interface parity but is not consumed by the
    kernel. Output feature count comes from the packed weight's shape
    (axis depends on TRANSPOSE_GPTQ_WEIGHT).
    """
    n_axis = -2 if TRANSPOSE_GPTQ_WEIGHT else -1
    out_shape = input_tensor.shape[:-1] + (weight.shape[n_axis], )  # M,N
    x2d = input_tensor.reshape(-1, input_tensor.shape[-1])
    # Kernel expects weight/scale transposed relative to our layout.
    y = torch.vacc.w4a8_block_int4_matmul(
        x2d,
        weight.transpose(-1, -2),
        weight_scale.transpose(-1, -2),
        [group_size_h, group_size_w],
    )
    if bias is not None:
        y = y + bias
    return y.reshape(out_shape)
def post_layernorm(
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    variance_size_override: Optional[int] = None,
    variance_epsilon: float = 1e-6,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """PyTorch-native RMSNorm with optional residual add and weight scale.

    Computes in float32 and casts back to the input dtype. When *residual*
    is given, returns ``(normed, x + residual)``; otherwise just the normed
    tensor. *variance_size_override* restricts the variance statistic to the
    first N features (requires a 3-D input in that case).
    """
    input_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.to(input_dtype)
    feature_dim = x.shape[-1]
    if variance_size_override is not None:
        if feature_dim < variance_size_override:
            raise ValueError(
                "Expected hidden_size to be at least "
                f"{variance_size_override}, but found: {feature_dim}")
        stats_src = x[:, :, :variance_size_override]
    else:
        stats_src = x
    inv_rms = torch.rsqrt(stats_src.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon)
    x = (x * inv_rms).to(input_dtype)
    if weight is not None:
        x = x * weight
    return x if residual is None else (x, residual)
# --- VACC Fused Kernel for Qwen2 ---
def vacc_fused_attn_qwen2_naive(
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    hidden_states_norm_weight: torch.Tensor,
    qkv_proj_weight: torch.Tensor,
    qkv_proj_weight_scale: torch.Tensor,
    qkv_proj_bias: Optional[torch.Tensor],
    qkv_proj_qzeros: Optional[torch.Tensor],
    sin_cache: List[torch.Tensor],
    cos_cache: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    block_tables: torch.Tensor,
    block_group_size: int,
    o_proj_weight: torch.Tensor,
    o_proj_weight_scale: torch.Tensor,
    o_proj_bias: Optional[torch.Tensor],
    o_proj_qzeros: Optional[torch.Tensor],
    seq_lens: List[int],
    sm_scale: float,
    num_attention_heads: int,
    num_key_value_heads: int,
    flash_attention: bool,
    is_decode: bool,
    reduce_result: bool,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    block_size: int = 16
):
    """Naive (reference) fused Qwen2 attention block on VACC.

    Pipeline: optional residual add -> fused RMSNorm -> GPTQ int4 QKV
    projection -> per-sequence RoPE + KV-cache write -> SDPA (prefill causal,
    decode against the reconstructed cache) -> GPTQ int4 output projection ->
    optional tensor-parallel all-reduce.

    Returns ``output`` or ``(output, residual_out)`` when *residual* is given,
    where ``residual_out`` is the post-add hidden states.
    NOTE(review): ``flash_attention`` is accepted but the SDPA calls hard-code
    ``flash_attention=False``; ``rank``/``group_id``/``dev_info`` are unused
    here — presumably kept for signature parity with the real fused kernel.
    """
    # qkv_proj_group_size = 128
    # o_proj_group_size = 64
    # Derive GPTQ group sizes from the packed-weight vs. scale shapes
    # (each int32 packs 8 int4 values, hence the *8).
    qkv_proj_group_size_h = qkv_proj_weight.shape[-1] * 8 // qkv_proj_weight_scale.shape[-1]
    qkv_proj_group_size_w = qkv_proj_weight.shape[-2] * 8 // qkv_proj_weight_scale.shape[-2]
    o_proj_group_size_h = o_proj_weight.shape[-1] * 8 // o_proj_weight_scale.shape[-1]
    o_proj_group_size_w = o_proj_weight.shape[-2] * 8 // o_proj_weight_scale.shape[-2]
    if residual is not None:
        hidden_states = hidden_states + residual
    residual_out = hidden_states
    save_tensor_pth("hidden_states", hidden_states)
    # 1. Fused RMSNorm
    hidden_states_norm = torch.vacc.rms_norm(
        hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)
    # NOTE: for qwen3 and qwen2.5, head_dim is always 128
    head_dim = 128
    # 2. QKV Projection using on-the-fly dequantization
    save_tensor_pth("hidden_states_norm", hidden_states_norm)
    qkv = apply_gptq_linear_(hidden_states_norm, qkv_proj_weight, qkv_proj_weight_scale, qkv_proj_bias, qkv_proj_qzeros, qkv_proj_group_size_h, qkv_proj_group_size_w)
    save_tensor_pth("qkv", qkv)
    # Split Q, K, V (per-rank head counts under tensor parallelism)
    num_q_heads = num_attention_heads // world_size
    num_kv_heads = max(1, num_key_value_heads // world_size)
    q_size = num_q_heads * head_dim
    kv_size = num_kv_heads * head_dim
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    save_tensor_pth("q", q)
    save_tensor_pth("k", k)
    save_tensor_pth("v", v)
    q = q.view(-1, num_q_heads, head_dim)
    k = k.view(-1, num_kv_heads, head_dim)
    v = v.view(-1, num_kv_heads, head_dim)
    save_tensor_pth("q_reshaped", q)
    save_tensor_pth("k_reshaped", k)
    save_tensor_pth("v_reshaped", v)
    # 3. RoPE, KV Caching, and Attention loop
    start = 0
    attn_outs = []
    if is_decode:
        # Convert block_tables indices from block_size granularity to
        # block_group_size ("8K group") granularity.
        block_per_group = block_group_size // block_size
        block_tables_grouped = (block_tables // block_per_group).to(torch.int32)
        # logger.warning(f"decode block table: {block_tables}")
    num_blocks = kv_cache.shape[1]
    key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
    value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)
    # Per-sequence (batch) loop
    for i, seq_len in enumerate(seq_lens):
        if not is_decode: # Prefill: consume the whole sequence
            end = start + seq_len
        else: # Decode: exactly one new token per sequence
            end = start + 1
        cos = cos_cache[i].unsqueeze(-2)
        sin = sin_cache[i].unsqueeze(-2)
        q_i, k_i, v_i = q[start:end], k[start:end], v[start:end]
        # Apply Rotary Positional Embedding
        q_rot, k_rot = torch.vacc.RotaryPosEmbedding(q_i, k_i, cos, sin, 0, "neox")
        save_tensor_pth("q_rot", q_rot)
        save_tensor_pth("k_rot", k_rot)
        # Reshape and cache K/V (note: K is cached post-RoPE, V raw)
        torch.vacc.reshape_and_cache_attention(k_rot, key_cache_split, slot_mapping[start : end, ...])
        torch.vacc.reshape_and_cache_attention(v_i, value_cache_split, slot_mapping[start : end, ...])
        # Attention calculation
        if not is_decode:
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q_rot, key=k_rot, value=v_i.contiguous(),
                attn_mask = None,
                dropout_p = 0.0,
                is_causal = True, #causal_attn and not self.need_mask,
                is_train = False,
                recompute = False,
                flash_attention = False,
                sm_scale=sm_scale)
        else:
            # For decode, reconstruct past K/V from the cache by gathering
            # this sequence's groups and trimming to seq_len tokens.
            key_cache_grouped = key_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
            value_cache_grouped = value_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
            # block_per_group = block_group_size // block_size
            # block_tables_grouped = (block_tables // block_per_group).to(torch.int32)
            k_slices = key_cache_grouped[block_tables_grouped[i], :, :, :]
            k_past = torch.cat([k_slices[j].unsqueeze(0) for j in range(len(block_tables_grouped[i]))], dim=0)
            k_past = k_past.reshape(-1, num_kv_heads, head_dim)[:seq_len]
            v_slices = value_cache_grouped[block_tables_grouped[i], :, :, :]
            v_past = torch.cat([v_slices[j].unsqueeze(0) for j in range(len(block_tables_grouped[i]))], dim=0)
            v_past = v_past.reshape(-1, num_kv_heads, head_dim)[:seq_len]
            # k_past = key_cache_split.reshape(-1, key_cache_split.shape[-2], key_cache_split.shape[-1])[:seq_len]
            # v_past = value_cache_split.reshape(-1, value_cache_split.shape[-2], value_cache_split.shape[-1])[:seq_len]
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q_rot.contiguous(), key=k_past.contiguous(), value=v_past.contiguous(),
                attn_mask=None,
                dropout_p=0,
                is_causal=False,
                is_train=False,
                recompute=False,
                flash_attention=False,#flash_attention
                sm_scale=sm_scale,)
        attn_outs.append(attn_out)
        start = end
    attn_output_cat = torch.cat(attn_outs, dim=0)
    save_tensor_pth("attn_output", attn_output_cat)
    # 4. Output Projection
    attn_output_reshaped = attn_output_cat.view(hidden_states.shape[0], -1)
    # print(f"attn_output_reshaped: {attn_output_reshaped}")
    output = apply_gptq_linear_(attn_output_reshaped,o_proj_weight, o_proj_weight_scale,o_proj_bias, o_proj_qzeros, o_proj_group_size_h, o_proj_group_size_w)
    # 5. Optional All-Reduce and post layernorm
    if reduce_result:
        output = tensor_model_parallel_all_reduce(output)
    if residual is not None:
        return output, residual_out
    return output
# --- Rewriting forward methods for Qwen2Attention and Qwen2DecoderLayer ---
# uniform the params names from different quantize method
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
    """Normalize per-quantization-method parameter names into *fused_params*.

    Writes '<name>_weight', '<name>_weight_scale', '<name>_bias' and
    '<name>_qzeros' so downstream fused kernels see a uniform layout
    regardless of which quant method produced the layer.
    """
    if isinstance(quant_method, UnquantizedLinearMethod):
        weight = layer.weight
        scale = torch.Tensor()  # placeholder: no scales for fp weights
        qzeros = None
    elif isinstance(quant_method, Fp8LinearMethod):
        weight = layer.weight
        scale = layer.weight_scale_inv
        qzeros = getattr(layer, 'qzeros', None)
    elif isinstance(quant_method, (GPTQLinearMethod, AWQLinearMethod)):
        # GPTQ and AWQ expose the same packed-weight attribute names.
        weight = layer.qweight
        scale = layer.scales
        qzeros = getattr(layer, 'qzeros', None)
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")
    fused_params[name + '_weight'] = weight
    fused_params[name + '_weight_scale'] = scale
    fused_params[name + '_bias'] = getattr(layer, 'bias', None)
    fused_params[name + '_qzeros'] = qzeros
# Global counter giving every dump call a unique batch ID.
# NOTE: This counter is process-specific. If you run multiple processes,
# each will have its own counter starting from 0. The directory structure
# already includes the process_id, so there won't be conflicts.
DUMP_COUNTER = 0
def dump_qwen2_attention_params(hidden_states, residual, fused_params, attn_metadata, is_decode,
                                sin_cache, cos_cache, kv_cache, env_blk_grp_size, scaling,
                                total_num_heads, total_num_kv_heads, world_size, rank, group_id, dev_info):
    """
    Collect every parameter that needs dumping and save them as one "batch".
    """
    # if is_decode:
    #     return
    # 1. Gather everything to save into one dict.
    #    Keys are file names (without extension), values are the data itself.
    #    Crucially, even None values go into the dict: the save function
    #    decides how to handle them, rather than skipping them here.
    params_to_dump = {
        'hidden_states': hidden_states,
        'residual': residual,
        'hidden_states_norm_weight': fused_params.get('input_layernorm_weight'),
        'qkv_proj_weight': fused_params.get('qkv_proj_weight'),
        'qkv_proj_weight_scale': fused_params.get('qkv_proj_weight_scale'),
        'qkv_proj_bias': fused_params.get('qkv_proj_bias'),
        'qkv_proj_qzeros': fused_params.get('qkv_proj_qzeros'),
        'o_proj_weight': fused_params.get('o_proj_weight'),
        'o_proj_weight_scale': fused_params.get('o_proj_weight_scale'),
        'o_proj_bias': fused_params.get('o_proj_bias'),
        'o_proj_qzeros': fused_params.get('o_proj_qzeros'),
        'sin_cache': sin_cache,
        'cos_cache': cos_cache,
        'kv_cache': kv_cache,
        'slot_mapping': attn_metadata.slot_mapping if attn_metadata else None,
        'block_tables': attn_metadata.block_tables if attn_metadata else None,
        'seq_lens': attn_metadata.seq_lens if attn_metadata else None,
        'block_group_size': env_blk_grp_size,
        'sm_scale': scaling,
        'num_attention_heads': total_num_heads,
        'num_key_value_heads': total_num_kv_heads,
        'world_size': world_size,
        'rank': rank,
        'group_id': group_id,
        'dev_info': dev_info,
    }
    # 2. Save the whole batch with a single call.
    save_dump_batch(params_to_dump)
def save_dump_batch(tensors_dict):
    """
    Save every entry of *tensors_dict* into a unique per-call subdirectory,
    so that all data from the same call lands together as one "batch".
    """
    global DUMP_COUNTER
    import time
    # Base layout: ./dump_qwen2_prefill_pth/{timestamp}/process_{pid}
    pid = os.getpid()
    hour_stamp = time.strftime("%Y-%m-%d_%H", time.localtime())
    base_dir = f'./dump_qwen2_prefill_pth/{hour_stamp}/process_{pid}'
    # One unique sub-directory per batch -- this is the "magic number".
    batch_dir = os.path.join(base_dir, f'dump_{DUMP_COUNTER}')
    DUMP_COUNTER += 1  # advance the global counter for the next call
    os.makedirs(batch_dir, exist_ok=True)
    print(f"--- Dumping batch to: {batch_dir} ---")
    for name, data in tensors_dict.items():
        # None entries are silently skipped -- no file is written for them.
        if data is None:
            continue
        file_path = os.path.join(batch_dir, f"{name}.pth")
        try:
            if isinstance(data, torch.Tensor):
                print(f' Saving tensor: {name}.pth, shape: {data.shape}, dtype: {data.dtype}')
                torch.save(data.cpu(), file_path)
            elif name in ['sin_cache', 'cos_cache'] and isinstance(data, (list, tuple)):
                # sin/cos caches are sequences of tensors; move each to CPU.
                cpu_tensors = [t.cpu() for t in data]
                print(f' Saving tensor list: {name}.pth, count: {len(cpu_tensors)}')
                torch.save(cpu_tensors, file_path)
            else:
                # Any other picklable metadata (numbers, strings, lists, ...).
                print(f' Saving metadata: {name}.pth, value: {data}')
                torch.save(data, file_path)
        except Exception as e:
            print(f" [ERROR] Failed to save {name}.pth: {e}")
    print(f"--- Finished dumping batch {DUMP_COUNTER - 1} ---")
def has_nan(t1, t2):
    """Return True if either tensor contains a NaN, printing a diagnostic.

    Fixed: the function is named ``has_nan`` but previously returned None;
    it now returns the boolean result (backward-compatible, since callers
    ignoring the return value are unaffected).
    """
    nan_in_t1 = bool(torch.isnan(t1).any())
    nan_in_t2 = bool(torch.isnan(t2).any())
    if nan_in_t1 or nan_in_t2:
        # Diagnostic output kept byte-identical to the original (Chinese):
        # "computation aborted because NaN values were detected".
        print("计算被中止因为检测到了NaN值。")
        if nan_in_t1:
            print(" - Tensor 1 包含 NaN。")
        if nan_in_t2:
            print(" - Tensor 2 包含 NaN。")
    return nan_in_t1 or nan_in_t2
class MlpSplitConfig:
    """SSRAM split/addressing plan for the fused VACC MLP kernel.

    Populated by validate_mlp_split_config(); the *_addr / *_buffer fields
    hold raw device SSRAM addresses (not host pointers).
    """
    def __init__(self):
        self.k_split = 0  # number of splits along the K (hidden) dimension
        self.n_split = 0  # number of splits along the N (intermediate) dimension
        self.w1_weight_addr = 0  # SSRAM address of the current w1 weight slice
        self.w3_weight_addr = 0  # SSRAM address of the current w3 weight slice
        self.matA_addr_0 = 0  # activation (matrix A) buffer, slot 0
        self.matA_addr_1 = 0  # activation buffer, slot 1 (double buffering assumed -- confirm)
        self.w1_splitk_buffer = 0  # split-K partial results of the w1 matmul
        self.w3_splitk_buffer = 0  # split-K partial results of the w3 matmul
        self.w13_multiply = 0  # buffer for the combined w1/w3 result (presumably elementwise product; TODO confirm against kernel)
        self.w2_splitn_buffer_list = []  # per-k_split buffers feeding the w2 matmul
        self.w13_dequant_block_size = 0  # dequant block size for w1/w3 (see get_dequant_block_size_int4)
        self.w2_dequant_block_size = 0  # dequant block size for w2
def _largest_dequant_block(dim_per_split: int, init_size: int = 128) -> int:
    """Largest power-of-two block size in [16, init_size] that divides
    *dim_per_split*, or -1 if none exists."""
    size = init_size
    while size >= 16 and dim_per_split % size != 0:
        size //= 2
    return size if size >= 16 else -1


def get_dequant_block_size_int4(k, n, k_split, n_split):
    """Pick int4 dequant block sizes for the split MLP weights.

    Returns ``[w13_block, w2_block]`` where w13_block is chosen against the
    per-split K dimension (k // k_split) and w2_block against the per-split
    N dimension (n // n_split). Each entry is the largest power of two in
    [16, 128] dividing its dimension, or -1 when no such size exists.
    (The original duplicated the halving loop twice; it is now one helper.)
    """
    return [
        _largest_dequant_block(k // k_split),
        _largest_dequant_block(n // n_split),
    ]
def get_split_config(total_split):
ret = []
max_split_n = math.ceil(math.sqrt(total_split))
for i in range(1, max_split_n + 1):
if total_split % i == 0:
cur_config = [total_split // i, i] # [k_split, n_split]
ret.append(cur_config)
split_config_size = len(ret)
for i in range(split_config_size - 1, -1, -1):
if ret[i][0] != ret[i][1]:
ret.append([ret[i][1], ret[i][0]])
return ret
def validate_mlp_split_config(k, n, k_split, n_split, config):
    """Check whether a (k_split, n_split) plan fits the SSRAM budget.

    On success, fills *config* (an MlpSplitConfig) with buffer addresses and
    dequant block sizes and returns True; returns False as soon as any
    divisibility or capacity constraint fails. Byte sizes below assume
    2-byte (fp16) operands and two 5 MB SSRAM bank groups (banks 0-3 and
    4-7); 0x54000000 / 0x50000000 are their base addresses -- assumed fixed
    by the VACC memory map, confirm against the hardware docs.
    """
    min_seqlen = 1024
    # Size-divisibility checks
    if k % k_split != 0:
        return False
    if n % (n_split * 4) != 0:
        return False
    if k % (k_split * 4) != 0:
        return False
    if (n / (n_split * 4)) % 16 != 0:
        return False
    if (k / (k_split * 4)) % 16 != 0:
        return False
    # Sizes of the weight slice and the input (activation) matrix
    w13_split_size = (k // k_split) * (n // n_split // 4) * 2 * 2
    matA_size = min_seqlen * (k // k_split // 4) * 2 * 2
    # Check that SSRAM banks 4-7 have enough room
    if w13_split_size + matA_size > 5 * 1024 * 1024:
        return False
    # Remaining space in banks 4-7
    ssram_4_to_7_size_left = 5 * 1024 * 1024 - (w13_split_size + matA_size)
    w13_result_size = min_seqlen * (n // n_split // 4) * 2
    # Check space for the three w13 result buffers
    if w13_result_size * 3 > ssram_4_to_7_size_left + 5 * 1024 * 1024:
        return False
    ssram_0_to_3_size_left = 5 * 1024 * 1024
    w13_result_save_type = -1
    # Decide where the w13 results live
    if w13_result_size * 3 <= 5 * 1024 * 1024:
        # type 0: all three result buffers fit in banks 0-3
        w13_result_save_type = 0
        ssram_0_to_3_size_left -= w13_result_size * 3
    elif w13_result_size * 2 <= 5 * 1024 * 1024:
        if w13_result_size <= ssram_4_to_7_size_left:
            # type 1: two buffers in banks 0-3, the third in banks 4-7
            w13_result_save_type = 1
            ssram_0_to_3_size_left -= w13_result_size * 2
            ssram_4_to_7_size_left -= w13_result_size
        else:
            return False
    else:
        return False
    # Handle n_split > 1: w2 needs k_split partial-sum buffers
    if n_split > 1:
        w2_split_k_buffer_total_size = min_seqlen * (k // k_split // 4) * 2 * k_split
        w2_split_k_buffer_size = min_seqlen * (k // k_split // 4) * 2
        if w2_split_k_buffer_total_size > ssram_0_to_3_size_left + ssram_4_to_7_size_left:
            return False
        available_blocks = (ssram_0_to_3_size_left // w2_split_k_buffer_size) + \
            (ssram_4_to_7_size_left // w2_split_k_buffer_size)
        if available_blocks < k_split:
            return False
    # All constraints hold -- record the plan into *config*
    config.k_split = k_split
    config.n_split = n_split
    config.w1_weight_addr = 0x54000000
    config.w3_weight_addr = config.w1_weight_addr + (k // k_split) * (n // n_split // 4) * 2
    config.matA_addr_0 = config.w3_weight_addr + (k // k_split) * (n // n_split // 4) * 2
    config.matA_addr_1 = config.matA_addr_0 + min_seqlen * (k // k_split // 4) * 2
    config.w1_splitk_buffer = 0x50000000
    config.w3_splitk_buffer = config.w1_splitk_buffer + min_seqlen * (n // n_split // 4) * 2
    if w13_result_save_type == 0:
        config.w13_multiply = config.w3_splitk_buffer + min_seqlen * (n // n_split // 4) * 2
    elif w13_result_save_type == 1:
        config.w13_multiply = config.matA_addr_1 + min_seqlen * (k // k_split // 4) * 2
    # Allocate the w2 partial-sum buffers
    if n_split > 1:
        w2_split_k_buffer_size = min_seqlen * (k // k_split // 4) * 2
        # Place as many buffers as fit in banks 0-3, the rest in banks 4-7.
        w2_split_k_buffer_in_ssram_0_to_3 = min(ssram_0_to_3_size_left // w2_split_k_buffer_size, k_split)
        w2_split_k_buffer_in_ssram_4_to_7 = k_split - w2_split_k_buffer_in_ssram_0_to_3
        if w13_result_save_type == 0:
            ssram_0_to_3_start_addr = config.w13_multiply + min_seqlen * (n // n_split // 4) * 2
            ssram_4_to_7_start_addr = config.matA_addr_1 + min_seqlen * (k // k_split // 4) * 2
        else:
            ssram_0_to_3_start_addr = config.w3_splitk_buffer + min_seqlen * (n // n_split // 4) * 2
            ssram_4_to_7_start_addr = config.w13_multiply + min_seqlen * (n // n_split // 4) * 2
        for i in range(w2_split_k_buffer_in_ssram_0_to_3):
            addr = ssram_0_to_3_start_addr + i * w2_split_k_buffer_size
            config.w2_splitn_buffer_list.append(addr)
        for i in range(w2_split_k_buffer_in_ssram_4_to_7):
            addr = ssram_4_to_7_start_addr + i * w2_split_k_buffer_size
            config.w2_splitn_buffer_list.append(addr)
    else:
        config.w2_splitn_buffer_list.append(0)
    # Pick the dequantization block sizes
    block_size = get_dequant_block_size_int4(k, n, k_split, n_split)
    if block_size[0] == -1 or block_size[1] == -1:
        return False
    config.w13_dequant_block_size = block_size[0]
    config.w2_dequant_block_size = block_size[1]
    return True
def get_mlp_split_schedule(k, n):
    """Search for a workable MLP split schedule for a k x n fp16 weight.

    Starts from the minimum split count implied by the 1.25 MB maximum
    weight-slice size and keeps increasing it (at most 100 attempts) until
    some [k_split, n_split] factorization validates.

    Returns:
        (True, MlpSplitConfig) on success, (False, None) otherwise.
    """
    config = MlpSplitConfig()
    split_count = math.ceil(k * n / 2 / 1310720)  # 1310720 B = 1.25 MB max weight slice
    for _ in range(100):  # bounded search over increasing split counts
        for k_split, n_split in get_split_config(split_count):
            if validate_mlp_split_config(k, n, k_split, n_split, config):
                return True, config
        split_count += 1
    return False, None
def Qwen2MLP__init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """VACC replacement for Qwen2MLP.__init__ (monkey-patched onto the class).

    Builds the standard gate_up/down projections, then — for GPTQ int4
    weights — precomputes the SSRAM split schedule used by the fused MLP
    kernel and reallocates the scale tensors so each dequant block has its
    own scale entry.

    Raises:
        ValueError: if the activation is not "silu", or no valid MLP split
            schedule exists for the given sizes.
    """
    # Initialize via Qwen2MLPOrig's MRO on purpose: this function replaces
    # the original class's __init__, so `self` is a Qwen2MLPOrig instance.
    super(Qwen2MLPOrig, self).__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    self.down_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.down_proj",
    )
    if hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {hidden_act}. "
                         "Only silu is supported for now.")
    self.act_fn = SiluAndMul()
    if hasattr(self.gate_up_proj, 'quant_method') and isinstance(self.gate_up_proj.quant_method, GPTQLinearMethod):
        intermediate_size_per_partition = intermediate_size // self.down_proj.tp_size
        success, config = get_mlp_split_schedule(hidden_size, intermediate_size_per_partition)
        if not success:
            # Fixed: the continuation fragments of this message were missing
            # the f-prefix, so "{intermediate_size}"/"{self.down_proj.tp_size}"
            # were printed literally instead of being interpolated.
            raise ValueError(
                f"Cannot find valid MLP split schedule for hidden_size {hidden_size}, "
                f"intermediate_size {intermediate_size}, tp_size {self.down_proj.tp_size}")
        self.w13_dequant_block_size = config.w13_dequant_block_size
        self.w2_dequant_block_size = config.w2_dequant_block_size
        group_size = self.gate_up_proj.quant_method.quant_config.group_size
        # scale_k = how many dequant blocks share one original GPTQ group.
        scale_k_w13 = group_size // config.w13_dequant_block_size
        scale_k_w2 = group_size // config.w2_dequant_block_size
        self.gate_up_proj.quant_method.scale_k = max(scale_k_w13, self.gate_up_proj.quant_method.scale_k)
        self.down_proj.quant_method.scale_k = max(scale_k_w2, self.down_proj.quant_method.scale_k)
        # With scale_k > 1 the scales tensor must hold scale_k entries per
        # group, so reallocate it with the enlarged first dimension.
        # NOTE(review): torch.empty leaves it uninitialized — presumably the
        # weight loader fills it later; confirm against the load path.
        if self.gate_up_proj.quant_method.scale_k > 1 and len(self.gate_up_proj.scales.data.shape) == 2 and \
                self.gate_up_proj.scales.data.dtype in [torch.float16, torch.bfloat16, torch.float32]:
            w13_scales = self.gate_up_proj.scales.data
            scale_and_zero_size = hidden_size * self.gate_up_proj.quant_method.scale_k // group_size
            w13_scales = torch.empty(
                scale_and_zero_size,
                w13_scales.shape[-1],
                dtype=w13_scales.dtype,
            )
            self.gate_up_proj.scales.data = w13_scales
        if self.down_proj.quant_method.scale_k > 1 and len(self.down_proj.scales.data.shape) == 2 and \
                self.down_proj.scales.data.dtype in [torch.float16, torch.bfloat16, torch.float32]:
            w2_scales = self.down_proj.scales.data
            scale_and_zero_size = intermediate_size_per_partition * self.down_proj.quant_method.scale_k // group_size
            w2_scales = torch.empty(
                scale_and_zero_size,
                w2_scales.shape[-1],
                dtype=w2_scales.dtype,
            )
            self.down_proj.scales.data = w2_scales
def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
                      attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]],
                      positions: Union[torch.Tensor, list],
                      is_decode: bool):
    """Build per-sequence cos/sin rotary caches.

    For MRoPE the fused caches come from get_sin_cos_mrope and are then
    split per sequence (chunked for decode, split by seq_lens for prefill).
    For plain RoPE the precomputed cos_cache/sin_cache tables are sliced:
    decode takes the last position of each sequence, prefill takes a prefix.

    Returns:
        (cos_cache, sin_cache): sequences of tensors, one entry per sequence.
    """
    if isinstance(rotary_emb, MRotaryEmbedding):
        # Multimodal rotary embedding: one fused cache for the whole batch.
        cos_cache, sin_cache = get_sin_cos_mrope(rotary_emb, positions)
        num_seqs = len(attn_metadata.seq_lens)
        if num_seqs <= 1:
            cos_cache = [cos_cache]
            sin_cache = [sin_cache]
        elif is_decode:
            # Decode: one token per sequence, so equal chunks.
            cos_cache = torch.chunk(cos_cache, num_seqs)
            sin_cache = torch.chunk(sin_cache, num_seqs)
        else:
            cos_cache = torch.split(cos_cache, attn_metadata.seq_lens)
            sin_cache = torch.split(sin_cache, attn_metadata.seq_lens)
    elif is_decode:
        # Decode: only the last position of each sequence is needed.
        last_positions = [sl - 1 for sl in attn_metadata.seq_lens]
        cos_cache = [rotary_emb.cos_cache[i:i+1, ...] for i in last_positions]
        sin_cache = [rotary_emb.sin_cache[i:i+1, ...] for i in last_positions]
    else:
        # Prefill: a prefix of the table covering each full sequence.
        cos_cache = [rotary_emb.cos_cache[:sl, ...] for sl in attn_metadata.seq_lens]
        sin_cache = [rotary_emb.sin_cache[:sl, ...] for sl in attn_metadata.seq_lens]
    return cos_cache, sin_cache
class Qwen2MLP(nn.Module):
    """VACC override of Qwen2's MLP forward.

    GPTQ int4 weights use the fused int4 MLP kernels (with the all-reduce
    fused in for small batches); unquantized fp16/bf16 weights use the fused
    kernel only for small batches. Every other case falls back to the stock
    gate_up -> SiLU-mul -> down projection path.
    """
    def forward(self, x):
        #TODO for other quant_method
        if hasattr(self.gate_up_proj, 'quant_method') and isinstance(self.gate_up_proj.quant_method, GPTQLinearMethod):
            # bit = self.gate_up_proj.quant_method.quant_config.weight_bits
            # pack_num = torch.iinfo(self.gate_up_proj.qweight.dtype).bits // bit
            # w13_group_size = pack_num * self.gate_up_proj.qweight.shape[1] // self.gate_up_proj.scales.shape[1]
            # w2_group_size = pack_num * self.down_proj.qweight.shape[1] // self.down_proj.scales.shape[1]
            tp_group = get_tp_group()
            batch_size = x.shape[0]
            if batch_size > 4:
                # Larger batches: compute locally, then all-reduce across TP ranks.
                y = torch.vacc.fuse_mlp_qwen_int4(
                    x,
                    self.gate_up_proj.qweight,
                    self.down_proj.qweight,
                    self.gate_up_proj.scales,
                    self.down_proj.scales,
                    self.gate_up_proj.qzeros,
                    self.down_proj.qzeros,
                    [1, self.w13_dequant_block_size],
                    [1, self.w2_dequant_block_size]
                )
                mlp_out = tensor_model_parallel_all_reduce(y)
                return mlp_out
            else:
                # Small batches: kernel variant with the all-reduce fused in.
                mlp_out = torch.vacc.fuse_mlp_qwen_int4_reduce(
                    x,
                    self.gate_up_proj.qweight,
                    self.down_proj.qweight,
                    self.gate_up_proj.scales,
                    self.down_proj.scales,
                    self.gate_up_proj.qzeros,
                    self.down_proj.qzeros,
                    [1, self.w13_dequant_block_size],
                    [1, self.w2_dequant_block_size],
                    world_size = tp_group.world_size,
                    rank = tp_group.rank_in_group,
                    group_id = tp_group.group_id,
                    dev_info = tp_group.rank_device_infos
                )
                return mlp_out
        elif hasattr(self.gate_up_proj, 'quant_method') and isinstance(self.gate_up_proj.quant_method, UnquantizedLinearMethod):
            batch_size = x.shape[0]
            if batch_size <= 4:
                # Small batches only: fused fp16/bf16 kernel with all-reduce.
                hidden_states = torch.vacc.fuse_mlp_qwen_fp16_bf16_reduce(x.view(-1, x.shape[-1]),
                    self.gate_up_proj.weight,
                    self.down_proj.weight,
                    world_size = get_tp_group().world_size,
                    rank = get_tp_group().rank_in_group,
                    group_id = get_tp_group().group_id,
                    dev_info = get_tp_group().rank_device_infos
                )
                return hidden_states
        # Fallback: standard vLLM MLP path (also reached for unquantized
        # weights with batch_size > 4).
        # NOTE(review): indentation reconstructed from a whitespace-stripped
        # source — assumed to be the shared fallback for all non-fused cases;
        # confirm against the original file.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
class Qwen2Attention(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,  # enables the fused norm+attention path
        cos_cache: Optional[list[torch.Tensor]] = None,
        sin_cache: Optional[list[torch.Tensor]] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Attention forward with three dispatch paths.

        * GPTQ-quantized projections -> VACC fused attention kernel
          (layernorm + QKV + RoPE + KV-cache write + attention + output
          projection in one call).
        * Unquantized projections    -> the same fused kernel with
          fp16/bf16 weights.
        * Anything else              -> stock qkv_proj / rotary_emb / attn /
          o_proj sequence.

        The fused paths return an ``(attn_out, residual_out)`` pair; the
        fallback path returns a single tensor (hence the Union return type).
        """
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            # Per-layer metadata map keyed by the attention layer's name.
            attn_metadata = attn_metadata[self.attn.layer_name]
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
        # No prefill metadata means every sequence in the batch is decoding.
        is_decode = attn_metadata.prefill_metadata is None
        use_gptq = hasattr(self.qkv_proj, 'quant_method') and isinstance(self.qkv_proj.quant_method, GPTQLinearMethod)
        use_unquan = hasattr(self.qkv_proj, 'quant_method') and isinstance(self.qkv_proj.quant_method, UnquantizedLinearMethod)
        if use_gptq:
            # Reshape the flat cache to (k/v, num_blocks, block_size=16,
            # kv_heads_per_rank, head_dim) — the layout the VACC kernel takes.
            kv_cache = kv_cache.view(2, -1, 16, self.total_num_kv_heads // get_tp_group().world_size, self.head_dim)
            # In decode the kernel performs the cross-rank reduction itself,
            # so the explicit all-reduce below is skipped.
            reduce_result = is_decode
            # NOTE(review): caller-supplied cos_cache/sin_cache is always
            # recomputed here, unlike the unquantized branch below which
            # reuses it — confirm whether this asymmetry is intentional.
            if is_decode:
                # Rotary slice per sequence is its last token position.
                # NOTE(review): the position is derived from seq_len - 1
                # rather than from `positions`; this assumes positions run
                # contiguously from 0 — confirm.
                positions = [i - 1 for i in attn_metadata.seq_lens]
                cos_cache = [self.rotary_emb.cos_cache[i:i+1, ...] for i in positions]
                sin_cache = [self.rotary_emb.sin_cache[i:i+1, ...] for i in positions]
            else:
                # Prefill: one (seq_len, ...) rotary slice per sequence.
                cos_cache = [self.rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
                sin_cache = [self.rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
            if residual is None:
                # First layer: the input itself becomes the residual stream.
                res_out = hidden_states
            if USE_FUSED_QWEN_ATTENTION:
                # Placeholder empty tensors are passed instead of the
                # corresponding fused_params entries for these arguments.
                qkv_proj_qzeros = torch.Tensor()
                o_proj_bias = torch.Tensor()
                o_proj_qzeros = torch.Tensor()
                qkv_proj_bias = self.fused_params['qkv_proj_bias']
                attn_outs = torch.vacc.fuse_atten_qwen2(
                    hidden_states,
                    residual,
                    self.fused_params['input_layernorm_weight'],
                    self.fused_params['qkv_proj_weight'],
                    self.fused_params['qkv_proj_weight_scale'],
                    qkv_proj_bias,
                    qkv_proj_qzeros,  # placeholder for fused_params['qkv_proj_qzeros']
                    sin_cache,
                    cos_cache,
                    attn_metadata.slot_mapping,
                    kv_cache,
                    attn_metadata.block_tables,
                    env_blk_grp_size,
                    self.fused_params['o_proj_weight'],
                    self.fused_params['o_proj_weight_scale'],
                    o_proj_bias,  # placeholder for fused_params['o_proj_bias']
                    o_proj_qzeros,  # placeholder for fused_params['o_proj_qzeros']
                    attn_metadata.seq_lens,
                    self.scaling,
                    self.total_num_heads,
                    self.total_num_kv_heads,
                    is_decode,
                    is_decode,  # flash-attention flag follows decode mode
                    reduce_result,
                    world_size=get_tp_group().world_size,
                    rank=get_tp_group().rank_in_group,
                    group_id=get_tp_group().group_id,
                    dev_info=get_tp_group().rank_device_infos)
            else:
                # Reference (naive) implementation of the fused kernel, used
                # when USE_FUSED_QWEN_ATTENTION is disabled.
                attn_outs = vacc_fused_attn_qwen2_naive(
                    hidden_states=hidden_states,
                    residual=residual,
                    hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                    qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                    qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                    qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                    qkv_proj_qzeros=None,
                    sin_cache=sin_cache,
                    cos_cache=cos_cache,
                    slot_mapping=attn_metadata.slot_mapping,
                    kv_cache=kv_cache,
                    block_tables=attn_metadata.block_tables,
                    block_group_size=env_blk_grp_size,
                    o_proj_weight=self.fused_params['o_proj_weight'],
                    o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                    o_proj_bias=None,
                    o_proj_qzeros=None,
                    seq_lens=attn_metadata.seq_lens,
                    sm_scale=self.scaling,
                    num_attention_heads=self.total_num_heads,
                    num_key_value_heads=self.total_num_kv_heads,
                    is_decode=is_decode,
                    flash_attention=is_decode,
                    reduce_result=reduce_result,
                    world_size=get_tp_group().world_size,
                    rank=get_tp_group().rank_in_group,
                    group_id=get_tp_group().group_id,
                    dev_info=get_tp_group().rank_device_infos
                )
            # The kernel returns a single tensor when no residual was given,
            # otherwise an (attn_out, residual_out) pair.  Only all-reduce
            # when the kernel has not already reduced across ranks.
            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
            return attn_out, res_out
        elif use_unquan:
            # Same fused path as the GPTQ branch above, but with fp16/bf16
            # weights and caller-supplied rotary caches honored when present.
            kv_cache = kv_cache.view(2, -1, 16, self.total_num_kv_heads // get_tp_group().world_size, self.head_dim)
            reduce_result = is_decode
            if cos_cache is None or sin_cache is None:
                cos_cache, sin_cache = get_cos_sin_cache(self.rotary_emb, attn_metadata, positions, is_decode)
            if residual is None:
                # First layer: the input itself becomes the residual stream.
                res_out = hidden_states
            if USE_FUSED_QWEN_ATTENTION:
                # Placeholders for quantization-only kernel parameters.
                qkv_proj_qzeros = torch.Tensor()
                o_proj_bias = torch.Tensor()
                o_proj_qzeros = torch.Tensor()
                attn_outs = torch.vacc.fuse_atten_qwen2(
                    hidden_states,
                    residual,
                    self.fused_params['input_layernorm_weight'],
                    self.fused_params['qkv_proj_weight'],
                    self.fused_params['qkv_proj_weight_scale'],
                    self.fused_params['qkv_proj_bias'],
                    qkv_proj_qzeros,  # placeholder for fused_params['qkv_proj_qzeros']
                    sin_cache,
                    cos_cache,
                    attn_metadata.slot_mapping,
                    kv_cache,
                    attn_metadata.block_tables,
                    env_blk_grp_size,
                    self.fused_params['o_proj_weight'],
                    self.fused_params['o_proj_weight_scale'],
                    o_proj_bias,  # placeholder for fused_params['o_proj_bias']
                    o_proj_qzeros,  # placeholder for fused_params['o_proj_qzeros']
                    attn_metadata.seq_lens,
                    self.scaling,
                    self.total_num_heads,
                    self.total_num_kv_heads,
                    is_decode,
                    is_decode,  # flash-attention flag follows decode mode
                    reduce_result,
                    world_size=get_tp_group().world_size,
                    rank=get_tp_group().rank_in_group,
                    group_id=get_tp_group().group_id,
                    dev_info=get_tp_group().rank_device_infos)
            else:
                # Reference (naive) implementation of the fused kernel.
                attn_outs = vacc_fused_attn_qwen2_naive(
                    hidden_states=hidden_states,
                    residual=residual,
                    hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                    qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                    qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                    qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                    qkv_proj_qzeros=None,
                    sin_cache=sin_cache,
                    cos_cache=cos_cache,
                    slot_mapping=attn_metadata.slot_mapping,
                    kv_cache=kv_cache,
                    block_tables=attn_metadata.block_tables,
                    block_group_size=env_blk_grp_size,
                    o_proj_weight=self.fused_params['o_proj_weight'],
                    o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                    o_proj_bias=None,
                    o_proj_qzeros=None,
                    seq_lens=attn_metadata.seq_lens,
                    sm_scale=self.scaling,
                    num_attention_heads=self.total_num_heads,
                    num_key_value_heads=self.total_num_kv_heads,
                    is_decode=is_decode,
                    flash_attention=is_decode,
                    reduce_result=reduce_result,
                    world_size=get_tp_group().world_size,
                    rank=get_tp_group().rank_in_group,
                    group_id=get_tp_group().group_id,
                    dev_info=get_tp_group().rank_device_infos
                )
            # Single tensor when residual was None; otherwise an
            # (attn_out, residual_out) pair.  Only all-reduce when the kernel
            # has not already reduced across ranks.
            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
            return attn_out, res_out
        else:
            # Fallback: stock unfused pipeline (QKV projection, RoPE,
            # attention, output projection).  Returns a single tensor.
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            q, k = self.rotary_emb(positions, q, k)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output
class Qwen2DecoderLayer(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer: self-attention block, then the MLP block.

        Returns the post-MLP hidden states together with the running
        residual stream.
        """
        qkv_proj = self.self_attn.qkv_proj
        # The fused-attention path covers GPTQ-quantized and plain
        # (unquantized) linear layers; other quant methods take the
        # stock pre-norm + attention path.
        fused_ok = hasattr(qkv_proj, 'quant_method') and isinstance(
            qkv_proj.quant_method, (GPTQLinearMethod, UnquantizedLinearMethod))
        if fused_ok:
            # Hand the pre-attention layernorm weight and the projection
            # parameters to the attention module so its fused kernel can do
            # norm + QKV + attention + output projection in a single call.
            attn = self.self_attn
            attn.input_layernorm_weight = self.input_layernorm.weight
            params = {'input_layernorm_weight': self.input_layernorm.weight}
            attn.fused_params = params
            set_fused_params(params, attn.qkv_proj.quant_method, attn.qkv_proj, 'qkv_proj')
            set_fused_params(params, attn.o_proj.quant_method, attn.o_proj, 'o_proj')
            hidden_states, residual = attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
        else:
            # Stock path: explicit pre-attention norm, then attention.
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
            )
        # MLP block with its own fused add-and-norm.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual