# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights, with VACC modifications."""
from typing import Iterable, Optional, Set, Tuple, Union, List, Any, Dict
import math
import os

import torch
import numpy as np
from torch import nn
from transformers import Qwen2Config

from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce, get_tp_group)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.forward_context import ForwardContext, get_forward_context

# Import original classes to be modified
from vllm.model_executor.models.qwen2 import (Qwen2Attention as Qwen2AttentionOrig,
                                              Qwen2DecoderLayer as Qwen2DecoderLayerOrig,
                                              Qwen2MLP as Qwen2MLPOrig)

from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm_vacc.vllm.model_executor.layers.layernorm import RMSNorm_forward_vacc
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
from ..ops.mrope_op import get_sin_cos_mrope
import copy

logger = init_logger(__name__)

# --- VACC specific additions ---
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT, USE_FUSED_QWEN_ATTENTION

from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size

def save_tensor_pth(name, tensor):
|
||
return
|
||
import time
|
||
if tensor is None:
|
||
return
|
||
|
||
process_id = os.getpid()
|
||
|
||
timestamp = time.strftime("%Y-%m-%d_%H", time.localtime())
|
||
pathdir = f'./dump_qwen_pth/{timestamp}/process_{process_id}'
|
||
|
||
if not os.path.exists(pathdir):
|
||
os.makedirs(pathdir)
|
||
|
||
    counter = 0
    filename = f"{name}_{counter}.pth"
    while os.path.exists(os.path.join(pathdir, filename)):
        counter += 1
        filename = f"{name}_{counter}.pth"
|
||
|
||
file_path = os.path.join(pathdir, filename)
|
||
|
||
if isinstance(tensor, torch.Tensor):
|
||
print(f'save {filename} {tensor.shape}')
|
||
torch.save(tensor.cpu(), file_path)
|
||
elif 'sin_cache' in name or 'cos_cache' in name:
|
||
tensor_list = [i.cpu() for i in tensor]
|
||
torch.save(tensor_list, file_path)
|
||
else:
|
||
torch.save(tensor, file_path)
|
||
return
|
||
|
||
# --- GPTQ Dequantization Logic (from gptq.py) ---
|
||
def int32_to_int4(s0, axis = -2):
|
||
    # Flatten first to shape [1, n].
    # Each int32 is unpacked into 8 int4 values (8 per int32), giving [8, n].

    # x32 (int32) => 32 bits => 8 x 4-bit nibbles => x4[0..7]

    # x32 bits 31-28 => x4[7]
    # x32 bits 27-24 => x4[6]
    # ...
    # x32 bits 3-0   => x4[0]

    # x32[index=0] => x4[7, 6, 5, 4, 3, 2, 1, 0]

    # Converting a 4-bit nibble to its real value:
    # not done via two's complement.

    # 1111 => 15 => 7
    # 15 - 8 = 7

    # 0110 => 6 => -2
    # 6 - 8 = -2

    # 0x6ACB372B (laid out in memory as 2B 37 CB 6A) => nibbles low-to-high B,2,7,3,B,C,A,6 => (subtract 8) => int4: 3, -6, -1, -5, 3, 4, 2, -2

    # The actual in-memory layout is little-endian:
    # int32 bytes 2B 37 CB 6A => 2,11,3,7,12,11,6,10 => (subtract 8) => -6,3, -5,-1, 4,3, -2,2 => swapping the two nibbles of each byte gives 3, -6, -1, -5, 3, 4, 2, -2
    # int4: 3, -6, -1, -5, 3, 4, 2, -2
|
||
|
||
s = s0.view(torch.uint32)
|
||
all = []
|
||
for i in range(8):
|
||
x = 15 << (i*4)
|
||
# s2 = torch.bitwise_and(x,s)
|
||
s2 = torch.from_numpy(np.bitwise_and(x, s.numpy()))
|
||
s3 = s2 / (2 ** (i*4))
|
||
s4 = s3.to(torch.int32)
|
||
        # Two's-complement style adjustment gives the wrong result:
        # s4[s4 > 7] = s4[s4 > 7] - 16
        # Subtracting 8 directly is correct; value range: -8..7
|
||
s4 = s4 - 8
|
||
all.append(s4.reshape(1,*s4.shape))
|
||
all = torch.concatenate(all, 0)
|
||
if axis == -2 or axis == 0:
|
||
# 8,K//8,N => K//8,8,N => K,N
|
||
all = all.transpose(-2,0).reshape(-1,all.shape[-1]).contiguous()
|
||
else:
|
||
# 8,N,K//8 => N,K//8,8 => N,K
|
||
all = all.permute(1,2,0).reshape(all.shape[-2],-1).contiguous()
|
||
return all
|
||
|
||
def dequant_weight(qw, scales, group_size = 128):
|
||
N = qw.shape[1]
|
||
int4_to_int32_axis = -2
|
||
if TRANSPOSE_GPTQ_WEIGHT:
|
||
N = qw.shape[0]
|
||
int4_to_int32_axis = -1
|
||
qweight = int32_to_int4(qw,int4_to_int32_axis).to(torch.float16) #int32 => 8 int4 +> fp16
|
||
|
||
if TRANSPOSE_GPTQ_WEIGHT:
|
||
scales = scales.T.contiguous()
|
||
qweight = qweight.T.contiguous()
|
||
|
||
    scales = torch.concatenate([scales] * group_size, 1).reshape(-1, N)  # expand scales by group_size: every group_size rows share one scale
|
||
|
||
# print('qweight', qweight.shape, qweight.dtype)
|
||
# print('scale', scales.shape, scales.dtype)
|
||
|
||
dequant_weight = qweight * scales #dequant
|
||
return dequant_weight
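
# Shape sketch (hypothetical sizes, not called at runtime; assumes TRANSPOSE_GPTQ_WEIGHT
# is False, i.e. the packed weight is stored as (K // 8, N) int32 with one fp16 scale
# per `group_size` rows). One int32 packs eight int4 values along K.
def _example_dequant_weight_shapes():
    K, N, group_size = 256, 64, 128
    qw = torch.randint(0, 2**31 - 1, (K // 8, N), dtype=torch.int32)
    scales = torch.ones(K // group_size, N, dtype=torch.float16)
    w = dequant_weight(qw, scales, group_size)
    assert w.shape == (K, N)  # unpacked, dequantized fp16 weight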
|
||
|
||
def apply_gptq_linear(
|
||
input_tensor: torch.Tensor,
|
||
layer: Union[QKVParallelLinear, RowParallelLinear],
|
||
) -> torch.Tensor:
|
||
"""
|
||
Applies a GPTQ-quantized linear layer by dequantizing weights on-the-fly.
|
||
"""
|
||
out_shape = input_tensor.shape[:-1] + (layer.qweight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], ) # M,N
|
||
|
||
reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1])
|
||
|
||
# This assumes the linear_method is attached and is a GPTQ derivative
|
||
linear_method = layer.quant_method
|
||
quant_config = linear_method.quant_config
|
||
|
||
# scale_k is a VACC-specific parameter for scale broadcasting
|
||
scale_k = getattr(linear_method, "scale_k", 1)
|
||
|
||
# Dequantize weight
|
||
# weight = dequant_weight(layer.qweight.data, layer.scales.data, quant_config.group_size // scale_k)
|
||
weight = dequant_weight(layer.qweight.cpu(), layer.scales.cpu(), quant_config.group_size // scale_k).to(layer.qweight.device)
|
||
|
||
# Perform GEMM
|
||
output = torch.matmul(reshaped_x, weight)
|
||
|
||
# Add bias if it exists
|
||
if hasattr(layer, 'bias') and layer.bias is not None:
|
||
output = output + layer.bias
|
||
|
||
# out_shape = input_tensor.shape[:-1] + (weight.shape[-1],)
|
||
return output.reshape(out_shape)
|
||
|
||
def get_gptq_group_size(layer: Union[QKVParallelLinear, RowParallelLinear]):
|
||
|
||
# This assumes the linear_method is attached and is a GPTQ derivative
|
||
linear_method = layer.quant_method
|
||
quant_config = linear_method.quant_config
|
||
|
||
# scale_k is a VACC-specific parameter for scale broadcasting
|
||
scale_k = getattr(linear_method, "scale_k", 1)
|
||
return quant_config.group_size // scale_k
|
||
|
||
def apply_gptq_linear_(
|
||
input_tensor: torch.Tensor,
|
||
weight: torch.Tensor,
|
||
weight_scale: torch.Tensor,
|
||
bias: Optional[torch.Tensor],
|
||
qzeros: Optional[torch.Tensor],
|
||
group_size_h: int,
|
||
group_size_w: int,
|
||
) -> torch.Tensor:
|
||
"""
|
||
Applies a GPTQ-quantized linear layer by dequantizing weights on-the-fly.
|
||
"""
|
||
out_shape = input_tensor.shape[:-1] + (weight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], ) # M,N
|
||
|
||
|
||
reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1])
|
||
|
||
# This assumes the linear_method is attached and is a GPTQ derivative
|
||
# linear_method = layer.quant_method
|
||
# quant_config = linear_method.quant_config
|
||
|
||
# # scale_k is a VACC-specific parameter for scale broadcasting
|
||
# scale_k = getattr(linear_method, "scale_k", 1)
|
||
|
||
# Dequantize weight
|
||
# weight = dequant_weight(layer.qweight.data, layer.scales.data, quant_config.group_size // scale_k)
|
||
# weight = dequant_weight(weight.cpu(), weight_scale.cpu(), group_size).to(weight.device)
|
||
|
||
# Perform GEMM
|
||
# output = torch.matmul(reshaped_x, weight)
|
||
# print("entering apply_gptq_linear_, reshaped_x shape:", reshaped_x.shape, "reshaped_x stride", reshaped_x.stride(), "input_tensor", input_tensor.shape, "weight shape:", weight.shape, "weight_scale shape:", weight_scale.shape, "group_size:", group_size)
|
||
output = torch.vacc.w4a8_block_int4_matmul(
|
||
reshaped_x,
|
||
weight.transpose(-1, -2),
|
||
weight_scale.transpose(-1, -2),
|
||
[group_size_h, group_size_w],
|
||
)
|
||
# print("exiting apply_gptq_linear_, output shape:", output.shape)
|
||
# Add bias if it exists
|
||
if bias is not None:
|
||
output = output + bias
|
||
|
||
# out_shape = input_tensor.shape[:-1] + (weight.shape[-1],)
|
||
return output.reshape(out_shape)
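
# Note (descriptive only): group_size_h / group_size_w are the dequantization block
# sizes along the two weight axes; callers derive them from the packed weight and
# scale shapes (eight int4 values per int32), see vacc_fused_attn_qwen2_naive below.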
|
||
|
||
def post_layernorm(
|
||
x: torch.Tensor,
|
||
residual: Optional[torch.Tensor] = None,
|
||
weight: Optional[torch.Tensor] = None,
|
||
variance_size_override: Optional[int] = None,
|
||
variance_epsilon: float = 1e-6,
|
||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||
"""PyTorch-native implementation equivalent to forward()."""
|
||
orig_dtype = x.dtype
|
||
x = x.to(torch.float32)
|
||
if residual is not None:
|
||
x = x + residual.to(torch.float32)
|
||
residual = x.to(orig_dtype)
|
||
|
||
hidden_size = x.shape[-1]
|
||
|
||
if variance_size_override is None:
|
||
x_var = x
|
||
else:
|
||
if hidden_size < variance_size_override:
|
||
raise ValueError(
|
||
"Expected hidden_size to be at least "
|
||
f"{variance_size_override}, but found: {hidden_size}")
|
||
|
||
x_var = x[:, :, :variance_size_override]
|
||
|
||
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
|
||
|
||
x = x * torch.rsqrt(variance + variance_epsilon)
|
||
x = x.to(orig_dtype)
|
||
if weight is not None:
|
||
x = x * weight
|
||
if residual is None:
|
||
return x
|
||
else:
|
||
return x, residual
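
# Minimal usage sketch (hypothetical tensors, not called in this file): the fused
# "add residual, then RMSNorm" step used between decoder sub-layers. When a residual
# is passed in, both the normalized activations and the updated residual are returned.
def _example_post_layernorm():
    hidden = torch.randn(2, 4, 8, dtype=torch.float16)
    residual = torch.randn(2, 4, 8, dtype=torch.float16)
    weight = torch.ones(8, dtype=torch.float16)
    normed, new_residual = post_layernorm(hidden, residual, weight, variance_epsilon=1e-6)
    assert normed.shape == hidden.shape and new_residual.shape == hidden.shape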
|
||
|
||
# --- VACC Fused Kernel for Qwen2 ---
|
||
def vacc_fused_attn_qwen2_naive(
|
||
hidden_states: torch.Tensor,
|
||
residual: Optional[torch.Tensor],
|
||
hidden_states_norm_weight: torch.Tensor,
|
||
qkv_proj_weight: torch.Tensor,
|
||
qkv_proj_weight_scale: torch.Tensor,
|
||
qkv_proj_bias: Optional[torch.Tensor],
|
||
qkv_proj_qzeros: Optional[torch.Tensor],
|
||
sin_cache: List[torch.Tensor],
|
||
cos_cache: List[torch.Tensor],
|
||
slot_mapping: torch.Tensor,
|
||
kv_cache: torch.Tensor,
|
||
block_tables: torch.Tensor,
|
||
block_group_size: int,
|
||
o_proj_weight: torch.Tensor,
|
||
o_proj_weight_scale: torch.Tensor,
|
||
o_proj_bias: Optional[torch.Tensor],
|
||
o_proj_qzeros: Optional[torch.Tensor],
|
||
seq_lens: List[int],
|
||
sm_scale: float,
|
||
num_attention_heads: int,
|
||
num_key_value_heads: int,
|
||
flash_attention: bool,
|
||
is_decode: bool,
|
||
reduce_result: bool,
|
||
world_size: int,
|
||
rank: int,
|
||
group_id: int,
|
||
dev_info: List[int] | Tuple[int],
|
||
block_size: int = 16
|
||
):
|
||
# qkv_proj_group_size = 128
|
||
# o_proj_group_size = 64
|
||
qkv_proj_group_size_h = qkv_proj_weight.shape[-1] * 8 // qkv_proj_weight_scale.shape[-1]
|
||
qkv_proj_group_size_w = qkv_proj_weight.shape[-2] * 8 // qkv_proj_weight_scale.shape[-2]
|
||
o_proj_group_size_h = o_proj_weight.shape[-1] * 8 // o_proj_weight_scale.shape[-1]
|
||
o_proj_group_size_w = o_proj_weight.shape[-2] * 8 // o_proj_weight_scale.shape[-2]
|
||
if residual is not None:
|
||
hidden_states = hidden_states + residual
|
||
residual_out = hidden_states
|
||
|
||
save_tensor_pth("hidden_states", hidden_states)
|
||
# 1. Fused RMSNorm
|
||
hidden_states_norm = torch.vacc.rms_norm(
|
||
hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)
|
||
# NOTE: for qwen3 and qwen2.5, head_dim is always 128
|
||
head_dim = 128
|
||
# 2. QKV Projection using on-the-fly dequantization
|
||
save_tensor_pth("hidden_states_norm", hidden_states_norm)
|
||
qkv = apply_gptq_linear_(hidden_states_norm, qkv_proj_weight, qkv_proj_weight_scale, qkv_proj_bias, qkv_proj_qzeros, qkv_proj_group_size_h, qkv_proj_group_size_w)
|
||
save_tensor_pth("qkv", qkv)
|
||
# Split Q, K, V
|
||
num_q_heads = num_attention_heads // world_size
|
||
num_kv_heads = max(1, num_key_value_heads // world_size)
|
||
q_size = num_q_heads * head_dim
|
||
kv_size = num_kv_heads * head_dim
|
||
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||
save_tensor_pth("q", q)
|
||
save_tensor_pth("k", k)
|
||
save_tensor_pth("v", v)
|
||
q = q.view(-1, num_q_heads, head_dim)
|
||
k = k.view(-1, num_kv_heads, head_dim)
|
||
v = v.view(-1, num_kv_heads, head_dim)
|
||
save_tensor_pth("q_reshaped", q)
|
||
save_tensor_pth("k_reshaped", k)
|
||
save_tensor_pth("v_reshaped", v)
|
||
# 3. RoPE, KV Caching, and Attention loop
|
||
start = 0
|
||
attn_outs = []
|
||
|
||
|
||
if is_decode:
|
||
# convert block_tables to 8K group index
|
||
block_per_group = block_group_size // block_size
|
||
block_tables_grouped = (block_tables // block_per_group).to(torch.int32)
|
||
# logger.warning(f"decode block table: {block_tables}")
|
||
|
||
num_blocks = kv_cache.shape[1]
|
||
key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
|
||
value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)
|
||
|
||
# bs loop
|
||
|
||
for i, seq_len in enumerate(seq_lens):
|
||
if not is_decode: # Prefill
|
||
end = start + seq_len
|
||
else: # Decode
|
||
end = start + 1
|
||
|
||
cos = cos_cache[i].unsqueeze(-2)
|
||
sin = sin_cache[i].unsqueeze(-2)
|
||
|
||
q_i, k_i, v_i = q[start:end], k[start:end], v[start:end]
|
||
|
||
# Apply Rotary Positional Embedding
|
||
q_rot, k_rot = torch.vacc.RotaryPosEmbedding(q_i, k_i, cos, sin, 0, "neox")
|
||
save_tensor_pth("q_rot", q_rot)
|
||
save_tensor_pth("k_rot", k_rot)
|
||
# Reshape and cache K/V
|
||
torch.vacc.reshape_and_cache_attention(k_rot, key_cache_split, slot_mapping[start : end, ...])
|
||
torch.vacc.reshape_and_cache_attention(v_i, value_cache_split, slot_mapping[start : end, ...])
|
||
|
||
# Attention calculation
|
||
if not is_decode:
|
||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||
query=q_rot, key=k_rot, value=v_i.contiguous(),
|
||
attn_mask = None,
|
||
dropout_p = 0.0,
|
||
is_causal = True, #causal_attn and not self.need_mask,
|
||
is_train = False,
|
||
recompute = False,
|
||
flash_attention = False,
|
||
sm_scale=sm_scale)
|
||
else:
|
||
# For decode, reconstruct past K/V from cache
|
||
key_cache_grouped = key_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
|
||
value_cache_grouped = value_cache_split.view(-1, block_group_size, num_kv_heads, head_dim)
|
||
# block_per_group = block_group_size // block_size
|
||
# block_tables_grouped = (block_tables // block_per_group).to(torch.int32)
|
||
|
||
k_slices = key_cache_grouped[block_tables_grouped[i], :, :, :]
|
||
k_past = torch.cat([k_slices[j].unsqueeze(0) for j in range(len(block_tables_grouped[i]))], dim=0)
|
||
k_past = k_past.reshape(-1, num_kv_heads, head_dim)[:seq_len]
|
||
|
||
v_slices = value_cache_grouped[block_tables_grouped[i], :, :, :]
|
||
v_past = torch.cat([v_slices[j].unsqueeze(0) for j in range(len(block_tables_grouped[i]))], dim=0)
|
||
v_past = v_past.reshape(-1, num_kv_heads, head_dim)[:seq_len]
|
||
# k_past = key_cache_split.reshape(-1, key_cache_split.shape[-2], key_cache_split.shape[-1])[:seq_len]
|
||
# v_past = value_cache_split.reshape(-1, value_cache_split.shape[-2], value_cache_split.shape[-1])[:seq_len]
|
||
|
||
attn_out = torch.vacc.scaled_dot_product_attention(
|
||
query=q_rot.contiguous(), key=k_past.contiguous(), value=v_past.contiguous(),
|
||
attn_mask=None,
|
||
dropout_p=0,
|
||
is_causal=False,
|
||
is_train=False,
|
||
recompute=False,
|
||
flash_attention=False,#flash_attention
|
||
sm_scale=sm_scale,)
|
||
|
||
attn_outs.append(attn_out)
|
||
start = end
|
||
|
||
attn_output_cat = torch.cat(attn_outs, dim=0)
|
||
|
||
save_tensor_pth("attn_output", attn_output_cat)
|
||
# 4. Output Projection
|
||
attn_output_reshaped = attn_output_cat.view(hidden_states.shape[0], -1)
|
||
# print(f"attn_output_reshaped: {attn_output_reshaped}")
|
||
output = apply_gptq_linear_(attn_output_reshaped,o_proj_weight, o_proj_weight_scale,o_proj_bias, o_proj_qzeros, o_proj_group_size_h, o_proj_group_size_w)
|
||
|
||
# 5. Optional All-Reduce and post layernorm
|
||
if reduce_result:
|
||
output = tensor_model_parallel_all_reduce(output)
|
||
|
||
if residual is not None:
|
||
return output, residual_out
|
||
|
||
return output
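
# Note (descriptive only): this naive implementation is the eager reference path for
# the fused torch.vacc.fuse_atten_qwen2 kernel invoked in Qwen2Attention.forward below;
# it takes essentially the same arguments and is kept for debugging and fallback.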
|
||
|
||
# --- Rewriting forward methods for Qwen2Attention and Qwen2DecoderLayer ---
|
||
|
||
# Unify the parameter names coming from the different quantization methods
|
||
def set_fused_params(fused_params: Dict[str, Any], quant_method: QuantizeMethodBase, layer: nn.Module, name: str):
|
||
if isinstance(quant_method, UnquantizedLinearMethod):
|
||
fused_params[name + '_weight'] = layer.weight
|
||
fused_params[name + '_weight_scale'] = torch.Tensor()
|
||
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
|
||
fused_params[name + '_qzeros'] = None
|
||
elif isinstance(quant_method, Fp8LinearMethod):
|
||
fused_params[name + '_weight'] = layer.weight
|
||
fused_params[name + '_weight_scale'] = layer.weight_scale_inv
|
||
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
|
||
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
|
||
elif isinstance(quant_method, GPTQLinearMethod):
|
||
fused_params[name + '_weight'] = layer.qweight
|
||
fused_params[name + '_weight_scale'] = layer.scales
|
||
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
|
||
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
|
||
elif isinstance(quant_method, AWQLinearMethod):
|
||
fused_params[name + '_weight'] = layer.qweight
|
||
fused_params[name + '_weight_scale'] = layer.scales
|
||
fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
|
||
fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
|
||
else:
|
||
raise ValueError(f"Unsupported quant_method: {quant_method}")
|
||
|
||
# Global counter used to give every dump operation a unique batch ID.
|
||
# NOTE: This counter is process-specific. If you run multiple processes,
|
||
# each will have its own counter starting from 0. The directory structure
|
||
# already includes the process_id, so there won't be conflicts.
|
||
DUMP_COUNTER = 0
|
||
|
||
def dump_qwen2_attention_params(hidden_states, residual, fused_params, attn_metadata, is_decode,
|
||
sin_cache, cos_cache, kv_cache, env_blk_grp_size, scaling,
|
||
total_num_heads, total_num_kv_heads, world_size, rank, group_id, dev_info):
|
||
"""
|
||
收集所有需要dump的参数,并以“批次”为单位进行保存。
|
||
"""
|
||
# if is_decode:
|
||
# return
|
||
|
||
    # 1. Collect every tensor / value that should be saved into a single dict.
    #    Keys are the file names (without extension); values are the data themselves.
    #    Crucially, even None values are put into the dict so that the save
    #    function decides how to handle them, instead of being skipped here.
|
||
params_to_dump = {
|
||
'hidden_states': hidden_states,
|
||
'residual': residual,
|
||
'hidden_states_norm_weight': fused_params.get('input_layernorm_weight'),
|
||
'qkv_proj_weight': fused_params.get('qkv_proj_weight'),
|
||
'qkv_proj_weight_scale': fused_params.get('qkv_proj_weight_scale'),
|
||
'qkv_proj_bias': fused_params.get('qkv_proj_bias'),
|
||
'qkv_proj_qzeros': fused_params.get('qkv_proj_qzeros'),
|
||
'o_proj_weight': fused_params.get('o_proj_weight'),
|
||
'o_proj_weight_scale': fused_params.get('o_proj_weight_scale'),
|
||
'o_proj_bias': fused_params.get('o_proj_bias'),
|
||
'o_proj_qzeros': fused_params.get('o_proj_qzeros'),
|
||
'sin_cache': sin_cache,
|
||
'cos_cache': cos_cache,
|
||
'kv_cache': kv_cache,
|
||
'slot_mapping': attn_metadata.slot_mapping if attn_metadata else None,
|
||
'block_tables': attn_metadata.block_tables if attn_metadata else None,
|
||
'seq_lens': attn_metadata.seq_lens if attn_metadata else None,
|
||
'block_group_size': env_blk_grp_size,
|
||
'sm_scale': scaling,
|
||
'num_attention_heads': total_num_heads,
|
||
'num_key_value_heads': total_num_kv_heads,
|
||
'world_size': world_size,
|
||
'rank': rank,
|
||
'group_id': group_id,
|
||
'dev_info': dev_info,
|
||
}
|
||
|
||
    # 2. Save the whole batch with a single call.
|
||
save_dump_batch(params_to_dump)
|
||
|
||
|
||
def save_dump_batch(tensors_dict):
|
||
"""
|
||
将一个字典中的所有张量和数据保存到一个唯一的、代表“同一次调用”的子目录中。
|
||
"""
|
||
global DUMP_COUNTER
|
||
import time
|
||
    # Base directory layout: ./dump_qwen2_prefill_pth/{timestamp}/process_{pid}
|
||
process_id = os.getpid()
|
||
timestamp = time.strftime("%Y-%m-%d_%H", time.localtime())
|
||
base_dir = f'./dump_qwen2_prefill_pth/{timestamp}/process_{process_id}'
|
||
|
||
|
||
    # Create a unique sub-directory for this dump "batch";
    # this is the per-call "magic number" index.
|
||
batch_dir = os.path.join(base_dir, f'dump_{DUMP_COUNTER}')
|
||
|
||
    # Increment the global counter in preparation for the next call
|
||
DUMP_COUNTER += 1
|
||
|
||
    # Create the directory if it does not exist
|
||
os.makedirs(batch_dir, exist_ok=True)
|
||
|
||
print(f"--- Dumping batch to: {batch_dir} ---")
|
||
|
||
    # Iterate over the dict and save every non-None entry
|
||
for name, data in tensors_dict.items():
|
||
if data is None:
|
||
            # Skip None values; no file is written for them
|
||
continue
|
||
|
||
file_path = os.path.join(batch_dir, f"{name}.pth")
|
||
|
||
try:
|
||
if isinstance(data, torch.Tensor):
|
||
print(f' Saving tensor: {name}.pth, shape: {data.shape}, dtype: {data.dtype}')
|
||
torch.save(data.cpu(), file_path)
|
||
elif name in ['sin_cache', 'cos_cache'] and isinstance(data, (list, tuple)):
|
||
                # Special handling for the sin/cos caches (lists of tensors)
|
||
tensor_list = [i.cpu() for i in data]
|
||
print(f' Saving tensor list: {name}.pth, count: {len(tensor_list)}')
|
||
torch.save(tensor_list, file_path)
|
||
else:
|
||
                # Save other picklable Python objects (numbers, strings, ...)
|
||
print(f' Saving metadata: {name}.pth, value: {data}')
|
||
torch.save(data, file_path)
|
||
except Exception as e:
|
||
print(f" [ERROR] Failed to save {name}.pth: {e}")
|
||
|
||
print(f"--- Finished dumping batch {DUMP_COUNTER - 1} ---")
|
||
|
||
def has_nan(t1, t2):
|
||
has_nan_1 = torch.isnan(t1).any()
|
||
has_nan_2 = torch.isnan(t2).any()
|
||
|
||
    # 2. Report which tensors contain NaN
    if has_nan_1 or has_nan_2:
        print("Computation aborted because NaN values were detected.")
        if has_nan_1:
            print(" - Tensor 1 contains NaN.")
        if has_nan_2:
            print(" - Tensor 2 contains NaN.")
    return has_nan_1 or has_nan_2
|
||
|
||
class MlpSplitConfig:
|
||
def __init__(self):
|
||
self.k_split = 0
|
||
self.n_split = 0
|
||
self.w1_weight_addr = 0
|
||
self.w3_weight_addr = 0
|
||
self.matA_addr_0 = 0
|
||
self.matA_addr_1 = 0
|
||
self.w1_splitk_buffer = 0
|
||
self.w3_splitk_buffer = 0
|
||
self.w13_multiply = 0
|
||
self.w2_splitn_buffer_list = []
|
||
self.w13_dequant_block_size = 0
|
||
self.w2_dequant_block_size = 0
|
||
|
||
def get_dequant_block_size_int4(k, n, k_split, n_split):
|
||
ret = []
|
||
w13_dequant_block_init_size = 128
|
||
k_per_split = k // k_split
|
||
|
||
if k_per_split % w13_dequant_block_init_size == 0:
|
||
ret.append(w13_dequant_block_init_size)
|
||
else:
|
||
while w13_dequant_block_init_size >= 16 and k_per_split % w13_dequant_block_init_size != 0:
|
||
w13_dequant_block_init_size //= 2
|
||
if w13_dequant_block_init_size >= 16:
|
||
ret.append(w13_dequant_block_init_size)
|
||
else:
|
||
ret.append(-1)
|
||
|
||
w2_dequant_block_init_size = 128
|
||
n_per_split = n // n_split
|
||
if n_per_split % w2_dequant_block_init_size == 0:
|
||
ret.append(w2_dequant_block_init_size)
|
||
else:
|
||
while w2_dequant_block_init_size >= 16 and n_per_split % w2_dequant_block_init_size != 0:
|
||
w2_dequant_block_init_size //= 2
|
||
if w2_dequant_block_init_size >= 16:
|
||
ret.append(w2_dequant_block_init_size)
|
||
else:
|
||
ret.append(-1)
|
||
|
||
return ret
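
# Illustrative values (hypothetical sizes, descriptive only): for k=3584, n=18944
# with k_split=2 and n_split=2 both per-split extents are multiples of 128, so the
# function returns [128, 128]; if a per-split extent is not divisible by any block
# size >= 16, the corresponding entry is -1 and the caller rejects that split.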
|
||
|
||
def get_split_config(total_split):
|
||
ret = []
|
||
max_split_n = math.ceil(math.sqrt(total_split))
|
||
|
||
for i in range(1, max_split_n + 1):
|
||
if total_split % i == 0:
|
||
cur_config = [total_split // i, i] # [k_split, n_split]
|
||
ret.append(cur_config)
|
||
|
||
split_config_size = len(ret)
|
||
for i in range(split_config_size - 1, -1, -1):
|
||
if ret[i][0] != ret[i][1]:
|
||
ret.append([ret[i][1], ret[i][0]])
|
||
|
||
return ret
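
# Illustrative check (sketch, not executed anywhere in this file): for a total split
# count of 4 the candidate [k_split, n_split] factor pairs enumerated above are
# [[4, 1], [2, 2], [1, 4]]; validate_mlp_split_config() later picks the first
# candidate that fits the SSRAM budget.
def _example_get_split_config():
    assert get_split_config(4) == [[4, 1], [2, 2], [1, 4]]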
|
||
|
||
def validate_mlp_split_config(k, n, k_split, n_split, config):
|
||
min_seqlen = 1024
|
||
    # Divisibility checks on the tile sizes
|
||
if k % k_split != 0:
|
||
return False
|
||
if n % (n_split * 4) != 0:
|
||
return False
|
||
if k % (k_split * 4) != 0:
|
||
return False
|
||
if (n / (n_split * 4)) % 16 != 0:
|
||
return False
|
||
if (k / (k_split * 4)) % 16 != 0:
|
||
return False
|
||
|
||
    # Compute the sizes of the weight tile and the input matrix
|
||
w13_split_size = (k // k_split) * (n // n_split // 4) * 2 * 2
|
||
matA_size = min_seqlen * (k // k_split // 4) * 2 * 2
|
||
|
||
    # Check that SSRAM banks 4-7 have enough space
|
||
if w13_split_size + matA_size > 5 * 1024 * 1024:
|
||
return False
|
||
|
||
    # Compute the remaining space
|
||
ssram_4_to_7_size_left = 5 * 1024 * 1024 - (w13_split_size + matA_size)
|
||
w13_result_size = min_seqlen * (n // n_split // 4) * 2
|
||
|
||
    # Check the space needed to store the results
|
||
if w13_result_size * 3 > ssram_4_to_7_size_left + 5 * 1024 * 1024:
|
||
return False
|
||
|
||
ssram_0_to_3_size_left = 5 * 1024 * 1024
|
||
w13_result_save_type = -1
|
||
|
||
    # Decide where the W1/W3 results are stored
|
||
if w13_result_size * 3 <= 5 * 1024 * 1024:
|
||
w13_result_save_type = 0
|
||
ssram_0_to_3_size_left -= w13_result_size * 3
|
||
elif w13_result_size * 2 <= 5 * 1024 * 1024:
|
||
if w13_result_size <= ssram_4_to_7_size_left:
|
||
w13_result_save_type = 1
|
||
ssram_0_to_3_size_left -= w13_result_size * 2
|
||
ssram_4_to_7_size_left -= w13_result_size
|
||
else:
|
||
return False
|
||
else:
|
||
return False
|
||
|
||
    # Handle the n_split > 1 case
|
||
if n_split > 1:
|
||
w2_split_k_buffer_total_size = min_seqlen * (k // k_split // 4) * 2 * k_split
|
||
w2_split_k_buffer_size = min_seqlen * (k // k_split // 4) * 2
|
||
|
||
if w2_split_k_buffer_total_size > ssram_0_to_3_size_left + ssram_4_to_7_size_left:
|
||
return False
|
||
|
||
available_blocks = (ssram_0_to_3_size_left // w2_split_k_buffer_size) + \
|
||
(ssram_4_to_7_size_left // w2_split_k_buffer_size)
|
||
if available_blocks < k_split:
|
||
return False
|
||
|
||
    # Fill in the config
|
||
config.k_split = k_split
|
||
config.n_split = n_split
|
||
config.w1_weight_addr = 0x54000000
|
||
config.w3_weight_addr = config.w1_weight_addr + (k // k_split) * (n // n_split // 4) * 2
|
||
config.matA_addr_0 = config.w3_weight_addr + (k // k_split) * (n // n_split // 4) * 2
|
||
config.matA_addr_1 = config.matA_addr_0 + min_seqlen * (k // k_split // 4) * 2
|
||
config.w1_splitk_buffer = 0x50000000
|
||
config.w3_splitk_buffer = config.w1_splitk_buffer + min_seqlen * (n // n_split // 4) * 2
|
||
|
||
if w13_result_save_type == 0:
|
||
config.w13_multiply = config.w3_splitk_buffer + min_seqlen * (n // n_split // 4) * 2
|
||
elif w13_result_save_type == 1:
|
||
config.w13_multiply = config.matA_addr_1 + min_seqlen * (k // k_split // 4) * 2
|
||
|
||
    # Allocate the W2 split buffers
|
||
if n_split > 1:
|
||
w2_split_k_buffer_size = min_seqlen * (k // k_split // 4) * 2
|
||
w2_split_k_buffer_in_ssram_0_to_3 = min(ssram_0_to_3_size_left // w2_split_k_buffer_size, k_split)
|
||
w2_split_k_buffer_in_ssram_4_to_7 = k_split - w2_split_k_buffer_in_ssram_0_to_3
|
||
|
||
if w13_result_save_type == 0:
|
||
ssram_0_to_3_start_addr = config.w13_multiply + min_seqlen * (n // n_split // 4) * 2
|
||
ssram_4_to_7_start_addr = config.matA_addr_1 + min_seqlen * (k // k_split // 4) * 2
|
||
else:
|
||
ssram_0_to_3_start_addr = config.w3_splitk_buffer + min_seqlen * (n // n_split // 4) * 2
|
||
ssram_4_to_7_start_addr = config.w13_multiply + min_seqlen * (n // n_split // 4) * 2
|
||
|
||
for i in range(w2_split_k_buffer_in_ssram_0_to_3):
|
||
addr = ssram_0_to_3_start_addr + i * w2_split_k_buffer_size
|
||
config.w2_splitn_buffer_list.append(addr)
|
||
|
||
for i in range(w2_split_k_buffer_in_ssram_4_to_7):
|
||
addr = ssram_4_to_7_start_addr + i * w2_split_k_buffer_size
|
||
config.w2_splitn_buffer_list.append(addr)
|
||
else:
|
||
config.w2_splitn_buffer_list.append(0)
|
||
|
||
    # Determine the dequantization block sizes
|
||
block_size = get_dequant_block_size_int4(k, n, k_split, n_split)
|
||
if block_size[0] == -1 or block_size[1] == -1:
|
||
return False
|
||
|
||
config.w13_dequant_block_size = block_size[0]
|
||
config.w2_dequant_block_size = block_size[1]
|
||
|
||
return True
|
||
|
||
def get_mlp_split_schedule(k, n):
|
||
    min_split = math.ceil(k * n / 2 / 1310720)  # max per-split weight size is 1.25 MB
|
||
current_split = min_split
|
||
config = MlpSplitConfig()
|
||
|
||
    for _ in range(100):  # try at most 100 split counts
|
||
split_configs = get_split_config(current_split)
|
||
for cfg in split_configs:
|
||
if validate_mlp_split_config(k, n, cfg[0], cfg[1], config):
|
||
return True, config
|
||
current_split += 1
|
||
|
||
return False, None
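
# Usage sketch (hypothetical Qwen2.5-7B-like sizes, for illustration only): look for an
# on-chip split schedule of the gate/up weight; `ok` is False when no (k_split, n_split)
# pair satisfies the SSRAM constraints checked above.
def _example_get_mlp_split_schedule():
    ok, cfg = get_mlp_split_schedule(k=3584, n=18944)
    if ok:
        print(cfg.k_split, cfg.n_split, cfg.w13_dequant_block_size, cfg.w2_dequant_block_size)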
|
||
|
||
|
||
def Qwen2MLP__init__(
|
||
self,
|
||
hidden_size: int,
|
||
intermediate_size: int,
|
||
hidden_act: str,
|
||
quant_config: Optional[QuantizationConfig] = None,
|
||
prefix: str = "",
|
||
) -> None:
|
||
super(Qwen2MLPOrig, self).__init__()
|
||
self.gate_up_proj = MergedColumnParallelLinear(
|
||
hidden_size,
|
||
[intermediate_size] * 2,
|
||
bias=False,
|
||
quant_config=quant_config,
|
||
prefix=f"{prefix}.gate_up_proj",
|
||
)
|
||
self.down_proj = RowParallelLinear(
|
||
intermediate_size,
|
||
hidden_size,
|
||
bias=False,
|
||
quant_config=quant_config,
|
||
prefix=f"{prefix}.down_proj",
|
||
)
|
||
if hidden_act != "silu":
|
||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||
"Only silu is supported for now.")
|
||
self.act_fn = SiluAndMul()
|
||
|
||
if hasattr(self.gate_up_proj, 'quant_method') and isinstance(self.gate_up_proj.quant_method, GPTQLinearMethod):
|
||
intermediate_size_per_partition = intermediate_size // self.down_proj.tp_size
|
||
success, config = get_mlp_split_schedule(hidden_size, intermediate_size_per_partition)
|
||
if success:
|
||
self.w13_dequant_block_size = config.w13_dequant_block_size
|
||
self.w2_dequant_block_size = config.w2_dequant_block_size
|
||
group_size = self.gate_up_proj.quant_method.quant_config.group_size
|
||
scale_k_w13 = group_size // config.w13_dequant_block_size
|
||
scale_k_w2 = group_size // config.w2_dequant_block_size
|
||
self.gate_up_proj.quant_method.scale_k = max(scale_k_w13, self.gate_up_proj.quant_method.scale_k)
|
||
self.down_proj.quant_method.scale_k = max(scale_k_w2, self.down_proj.quant_method.scale_k)
|
||
if self.gate_up_proj.quant_method.scale_k > 1 and len(self.gate_up_proj.scales.data.shape) == 2 and \
|
||
self.gate_up_proj.scales.data.dtype in [torch.float16, torch.bfloat16, torch.float32]:
|
||
w13_scales = self.gate_up_proj.scales.data
|
||
scale_and_zero_size = hidden_size * self.gate_up_proj.quant_method.scale_k // group_size
|
||
w13_scales = torch.empty(
|
||
scale_and_zero_size,
|
||
w13_scales.shape[-1],
|
||
dtype=w13_scales.dtype,
|
||
)
|
||
self.gate_up_proj.scales.data = w13_scales
|
||
if self.down_proj.quant_method.scale_k > 1 and len(self.down_proj.scales.data.shape) == 2 and \
|
||
self.down_proj.scales.data.dtype in [torch.float16, torch.bfloat16, torch.float32]:
|
||
w2_scales = self.down_proj.scales.data
|
||
scale_and_zero_size = intermediate_size_per_partition * self.down_proj.quant_method.scale_k // group_size
|
||
w2_scales = torch.empty(
|
||
scale_and_zero_size,
|
||
w2_scales.shape[-1],
|
||
dtype=w2_scales.dtype,
|
||
)
|
||
self.down_proj.scales.data = w2_scales
|
||
else:
|
||
raise ValueError(f"Cannot find valid MLP split schedule for hidden_size {hidden_size}, "
|
||
"intermediate_size {intermediate_size}, tp_size {self.down_proj.tp_size}")
|
||
|
||
def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
|
||
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]],
|
||
positions: Union[torch.Tensor, list],
|
||
is_decode: bool):
|
||
if isinstance(rotary_emb, MRotaryEmbedding):
|
||
# get mrope sin/cos
|
||
cos_cache, sin_cache = get_sin_cos_mrope(rotary_emb, positions)
|
||
if len(attn_metadata.seq_lens) > 1:
|
||
if is_decode:
|
||
cos_cache = torch.chunk(cos_cache, len(attn_metadata.seq_lens))
|
||
sin_cache = torch.chunk(sin_cache, len(attn_metadata.seq_lens))
|
||
else:
|
||
cos_cache = torch.split(cos_cache, attn_metadata.seq_lens)
|
||
sin_cache = torch.split(sin_cache, attn_metadata.seq_lens)
|
||
else:
|
||
cos_cache = [cos_cache]
|
||
sin_cache = [sin_cache]
|
||
else:
|
||
if is_decode:
|
||
positions = [i - 1 for i in attn_metadata.seq_lens]
|
||
cos_cache = [rotary_emb.cos_cache[i:i+1, ...] for i in positions]
|
||
sin_cache = [rotary_emb.sin_cache[i:i+1, ...] for i in positions]
|
||
else:
|
||
cos_cache = [rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
|
||
sin_cache = [rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
|
||
return cos_cache, sin_cache
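
# Note on the slicing above (descriptive only): during decode every sequence contributes
# exactly one new token, so a single cos/sin row per sequence is taken at index
# seq_len - 1; during prefill the whole prefix [0, seq_len) is sliced. For
# MRotaryEmbedding the per-position values come from get_sin_cos_mrope and are split
# per sequence instead.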
|
||
|
||
class Qwen2MLP(nn.Module):
|
||
def forward(self, x):
|
||
        # TODO: support other quantization methods here as well
|
||
if hasattr(self.gate_up_proj, 'quant_method') and isinstance(self.gate_up_proj.quant_method, GPTQLinearMethod):
|
||
# bit = self.gate_up_proj.quant_method.quant_config.weight_bits
|
||
# pack_num = torch.iinfo(self.gate_up_proj.qweight.dtype).bits // bit
|
||
|
||
# w13_group_size = pack_num * self.gate_up_proj.qweight.shape[1] // self.gate_up_proj.scales.shape[1]
|
||
# w2_group_size = pack_num * self.down_proj.qweight.shape[1] // self.down_proj.scales.shape[1]
|
||
tp_group = get_tp_group()
|
||
batch_size = x.shape[0]
|
||
|
||
if batch_size > 4:
|
||
y = torch.vacc.fuse_mlp_qwen_int4(
|
||
x,
|
||
self.gate_up_proj.qweight,
|
||
self.down_proj.qweight,
|
||
self.gate_up_proj.scales,
|
||
self.down_proj.scales,
|
||
self.gate_up_proj.qzeros,
|
||
self.down_proj.qzeros,
|
||
[1, self.w13_dequant_block_size],
|
||
[1, self.w2_dequant_block_size]
|
||
)
|
||
mlp_out = tensor_model_parallel_all_reduce(y)
|
||
return mlp_out
|
||
else:
|
||
mlp_out = torch.vacc.fuse_mlp_qwen_int4_reduce(
|
||
x,
|
||
self.gate_up_proj.qweight,
|
||
self.down_proj.qweight,
|
||
self.gate_up_proj.scales,
|
||
self.down_proj.scales,
|
||
self.gate_up_proj.qzeros,
|
||
self.down_proj.qzeros,
|
||
[1, self.w13_dequant_block_size],
|
||
[1, self.w2_dequant_block_size],
|
||
world_size = tp_group.world_size,
|
||
rank = tp_group.rank_in_group,
|
||
group_id = tp_group.group_id,
|
||
dev_info = tp_group.rank_device_infos
|
||
)
|
||
return mlp_out
|
||
|
||
elif hasattr(self.gate_up_proj, 'quant_method') and isinstance(self.gate_up_proj.quant_method, UnquantizedLinearMethod):
|
||
batch_size = x.shape[0]
|
||
if batch_size <= 4:
|
||
hidden_states = torch.vacc.fuse_mlp_qwen_fp16_bf16_reduce(x.view(-1, x.shape[-1]),
|
||
self.gate_up_proj.weight,
|
||
self.down_proj.weight,
|
||
world_size = get_tp_group().world_size,
|
||
rank = get_tp_group().rank_in_group,
|
||
group_id = get_tp_group().group_id,
|
||
dev_info = get_tp_group().rank_device_infos
|
||
)
|
||
return hidden_states
|
||
|
||
gate_up, _ = self.gate_up_proj(x)
|
||
x = self.act_fn(gate_up)
|
||
x, _ = self.down_proj(x)
|
||
return x
|
||
|
||
|
||
class Qwen2Attention(nn.Module):
|
||
|
||
def forward(
|
||
self,
|
||
positions: torch.Tensor,
|
||
hidden_states: torch.Tensor,
|
||
residual: Optional[torch.Tensor] = None, # Added for fused kernel
|
||
cos_cache: list[torch.Tensor] = None,
|
||
sin_cache: list[torch.Tensor] = None,
|
||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||
|
||
forward_context: ForwardContext = get_forward_context()
|
||
attn_metadata = forward_context.attn_metadata
|
||
if isinstance(attn_metadata, dict):
|
||
attn_metadata = attn_metadata[self.attn.layer_name]
|
||
kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
|
||
is_decode = attn_metadata.prefill_metadata is None
|
||
# print('attn_metadata', attn_metadata)
|
||
use_gptq = hasattr(self.qkv_proj, 'quant_method') and isinstance(self.qkv_proj.quant_method, GPTQLinearMethod)
|
||
use_unquan = hasattr(self.qkv_proj, 'quant_method') and isinstance(self.qkv_proj.quant_method, UnquantizedLinearMethod)
|
||
# if USE_FUSED_QWEN_ATTENTION and not is_decode:
|
||
if use_gptq:
|
||
|
||
# Determine if we should reduce the result inside the kernel
|
||
# print(f"Qwen2Attention forward, is_decode: {is_decode}, kv_cache shape: {kv_cache.shape}, attn_metadata.seq_lens: {attn_metadata.seq_lens}")
|
||
kv_cache = kv_cache.view(2, -1, 16, self.total_num_kv_heads // get_tp_group().world_size, self.head_dim)
|
||
|
||
reduce_result = is_decode
|
||
# reduce_result = False
|
||
# total_bytes = hidden_states.numel() * hidden_states.element_size() * get_tp_group().world_size
|
||
# if total_bytes < 4194304: # Heuristic from qwen3_moe_vacc
|
||
# reduce_result = True
|
||
|
||
# Get rotary caches
|
||
# cos_cache = []
|
||
# sin_cache = []
|
||
|
||
# start = 0
|
||
# for seq_len in attn_metadata.seq_lens:
|
||
# if not is_decode:
|
||
# end = start + seq_len
|
||
# else:
|
||
# end = start + 1
|
||
# # NOTE:
|
||
# # in prefill stage, the value from positions[start] to positions[end] should be always contiguous
|
||
# cos_cache.append(self.rotary_emb.cos_cache[positions[start]:positions[end-1] + 1, ...])
|
||
# sin_cache.append(self.rotary_emb.sin_cache[positions[start]:positions[end-1] + 1, ...])
|
||
# start = end
|
||
if is_decode:
|
||
positions = [i - 1 for i in attn_metadata.seq_lens]
|
||
cos_cache = [self.rotary_emb.cos_cache[i:i+1, ...] for i in positions]
|
||
sin_cache = [self.rotary_emb.sin_cache[i:i+1, ...] for i in positions]
|
||
else:
|
||
cos_cache = [self.rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
|
||
sin_cache = [self.rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
|
||
|
||
if residual is None:
|
||
res_out = hidden_states
|
||
|
||
if USE_FUSED_QWEN_ATTENTION:
|
||
qkv_proj_qzeros = torch.Tensor()
|
||
o_proj_bias = torch.Tensor()
|
||
o_proj_qzeros = torch.Tensor()
|
||
qkv_proj_bias = self.fused_params['qkv_proj_bias']
|
||
# if is_decode:
|
||
# qkv_proj_bias = self.fused_params['qkv_proj_bias']
|
||
# if hidden_states.shape[0] > 1:
|
||
# qkv_proj_bias = qkv_proj_bias.repeat(len(attn_metadata.seq_lens), 1)
|
||
# else:
|
||
# qkv_proj_bias = self.fused_params['qkv_proj_bias']
|
||
# qkv_proj_bias = qkv_proj_bias.repeat(sum(attn_metadata.seq_lens), 1)
|
||
|
||
# print("Qwen2Attention forward, is_decode:", is_decode,
|
||
# "kv_cache shape:", kv_cache.shape,
|
||
# "attn_metadata.seq_lens:", attn_metadata.seq_lens,
|
||
# "qkv_proj_bias shape:", qkv_proj_bias.shape,
|
||
# "qkv_proj_bias:", qkv_proj_bias,
|
||
# "qkv_proj_qzeros shape:", qkv_proj_qzeros.shape,
|
||
# "o_proj_bias shape:", o_proj_bias.shape,
|
||
# "o_proj_bias:", o_proj_bias,
|
||
# "o_proj_qzeros shape:", o_proj_qzeros.shape,
|
||
# "o_proj_qzeros:", o_proj_qzeros,
|
||
# "reduce_result:", reduce_result,
|
||
# "world_size:", get_tp_group().world_size,
|
||
# "rank:", get_tp_group().rank_in_group,
|
||
# "group_id:", get_tp_group().group_id,
|
||
# "dev_info:", get_tp_group().rank_device_infos)
|
||
# print("residual shape:", residual.shape if residual is not None else "None",
|
||
# "hidden_states shape:", hidden_states.shape,
|
||
# "hidden_states dtype:", hidden_states.dtype,
|
||
# "hidden_states device:", hidden_states.device,
|
||
# "attn_metadata.slot_mapping shape:", attn_metadata.slot_mapping.shape if attn_metadata.slot_mapping is not None else "None",
|
||
# "attn_metadata.block_tables shape:", attn_metadata.block_tables.shape if attn_metadata.block_tables is not None else "None",
|
||
# "attn_metadata.block_tables dtype:", attn_metadata.block_tables.dtype if attn_metadata.block_tables is not None else "None",
|
||
# "attn_metadata.block_tables device:", attn_metadata.block_tables.device if attn_metadata.block_tables is not None else "None",
|
||
# "attn_metadata.seq_lens:", attn_metadata.seq_lens,
|
||
# "attn_metadata.seq_lens length:", len(attn_metadata.seq_lens),
|
||
# "attn_metadata.seq_lens dtype:", type(attn_metadata.seq_lens),
|
||
# # "attn_metadata.seq_lens device:", attn_metadata.seq_lens[0].device if attn_metadata.seq_lens else "None",
|
||
# "attn_metadata.seq_lens type:", type(attn_metadata.seq_lens[0]) if attn_metadata.seq_lens else "None",
|
||
# "attn_metadata.seq_lens content:", attn_metadata.seq_lens if attn_metadata.seq_lens else "None",
|
||
# "attn_metadata.slot_mapping dtype:", attn_metadata.slot_mapping.dtype if attn_metadata.slot_mapping is not None else "None",
|
||
# "attn_metadata.slot_mapping device:", attn_metadata.slot_mapping.device if attn_metadata.slot_mapping is not None else "None",
|
||
# "attn_metadata.slot_mapping type:", type(attn_metadata.slot_mapping) if attn_metadata.slot_mapping is not None else "None",
|
||
# "attn_metadata.slot_mapping content:", attn_metadata.slot_mapping if attn_metadata.slot_mapping is not None else "None",
|
||
# "self.fused_params['input_layernorm_weight'] shape:", self.fused_params['input_layernorm_weight'].shape if 'input_layernorm_weight' in self.fused_params else "None",
|
||
# "self.fused_params['input_layernorm_weight'] dtype:", self.fused_params['input_layernorm_weight'].dtype if 'input_layernorm_weight' in self.fused_params else "None",
|
||
# "self.fused_params['input_layernorm_weight'] device:", self.fused_params['input_layernorm_weight'].device if 'input_layernorm_weight' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_weight'] shape:", self.fused_params['qkv_proj_weight'].shape if 'qkv_proj_weight' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_weight'] dtype:", self.fused_params['qkv_proj_weight'].dtype if 'qkv_proj_weight' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_weight'] device:", self.fused_params['qkv_proj_weight'].device if 'qkv_proj_weight' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_weight_scale'] shape:", self.fused_params['qkv_proj_weight_scale'].shape if 'qkv_proj_weight_scale' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_weight_scale'] dtype:", self.fused_params['qkv_proj_weight_scale'].dtype if 'qkv_proj_weight_scale' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_weight_scale'] device:", self.fused_params['qkv_proj_weight_scale'].device if 'qkv_proj_weight_scale' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_bias'] shape:", self.fused_params['qkv_proj_bias'].shape if 'qkv_proj_bias' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_bias'] dtype:", self.fused_params['qkv_proj_bias'].dtype if 'qkv_proj_bias' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_bias'] device:", self.fused_params['qkv_proj_bias'].device if 'qkv_proj_bias' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_qzeros'] shape:", self.fused_params['qkv_proj_qzeros'].shape if 'qkv_proj_qzeros' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_qzeros'] dtype:", self.fused_params['qkv_proj_qzeros'].dtype if 'qkv_proj_qzeros' in self.fused_params else "None",
|
||
# "self.fused_params['qkv_proj_qzeros'] device:", self.fused_params['qkv_proj_qzeros'].device if 'qkv_proj_qzeros' in self.fused_params else "None",
|
||
# "self.fused_params['o_proj_weight'] shape:", self.fused_params['o_proj_weight'].shape if 'o_proj_weight' in self.fused_params else "None",
|
||
# "self.fused_params['o_proj_weight'] dtype:", self.fused_params['o_proj_weight'].dtype if 'o_proj_weight' in self.fused_params else "None",
|
||
# "self.fused_params['o_proj_weight'] device:", self.fused_params['o_proj_weight'].device if 'o_proj_weight' in self.fused_params else "None",
|
||
# "self.fused_params['o_proj_weight_scale'] shape:", self.fused_params['o_proj_weight_scale'].shape if 'o_proj_weight_scale' in self.fused_params else "None",
|
||
# "self.fused_params['o_proj_weight_scale'] dtype:", self.fused_params['o_proj_weight_scale'].dtype if 'o_proj_weight_scale' in self.fused_params else "None",
|
||
# "self.fused_params['o_proj_weight_scale'] device:", self.fused_params['o_proj_weight_scale'].device if 'o_proj_weight_scale' in self.fused_params else "None",
|
||
# # "self.fused_params['o_proj_bias'] shape:", self.fused_params['o_proj_bias'].shape if 'o_proj_bias' in self.fused_params else "None",
|
||
# # "self.fused_params['o_proj_bias'] dtype:", self.fused_params['o_proj_bias'].dtype if 'o_proj_bias' in self.fused_params else "None",
|
||
# # "self.fused_params['o_proj_bias'] device:", self.fused_params['o_proj_bias'].device if 'o_proj_bias' in self.fused_params else "None",
|
||
# # "self.fused_params['o_proj_qzeros'] shape:", self.fused_params['o_proj_qzeros'].shape if 'o_proj_qzeros' in self.fused_params else "None",
|
||
# # "self.fused_params['o_proj_qzeros'] dtype:", self.fused_params['o_proj_qzeros'].dtype if 'o_proj_qzeros' in self.fused_params else "None",
|
||
# # "self.fused_params['o_proj_qzeros'] device:", self.fused_params['o_proj_qzeros'].device if 'o_proj_qzeros' in self.fused_params else "None",
|
||
# "sin_cache length:", len(sin_cache),
|
||
# "cos_cache length:", len(cos_cache),
|
||
# "sin_cache[0] shape:", sin_cache[0].shape if sin_cache else "None",
|
||
# "cos_cache[0] shape:", cos_cache[0].shape if cos_cache else "None",
|
||
# "sin_cache[0] dtype:", sin_cache[0].dtype if sin_cache else "None",
|
||
# "cos_cache[0] dtype:", cos_cache[0].dtype if cos_cache else "None",
|
||
# "sin_cache[0] device:", sin_cache[0].device if sin_cache else "None",
|
||
# "cos_cache[0] device:", cos_cache[0].device if cos_cache else "None",
|
||
# "kv_cache shape:", kv_cache.shape,
|
||
# "kv_cache dtype:", kv_cache.dtype,
|
||
# "kv_cache device:", kv_cache.device,
|
||
# "self.scaling:", self.scaling,
|
||
# "self.total_num_heads:", self.total_num_heads,
|
||
# "self.total_num_kv_heads:", self.total_num_kv_heads,
|
||
# "is_decode:", is_decode,
|
||
# "reduce_result:", reduce_result
|
||
# )
|
||
|
||
# Call the naive fused kernel
|
||
####################################call fusion kernel######################################
|
||
# print("starting fused attention kernel......................................, is_decode:", is_decode)
|
||
# print('hidden_states', hidden_states.shape)
|
||
# print('attn_metadata.slot_mapping,', attn_metadata.slot_mapping,)
|
||
# print('attn_metadata.block_tables,', attn_metadata.block_tables,)
|
||
# print('attn_metadata.seq_lens,', attn_metadata.seq_lens,)
|
||
# print('is_decode,', is_decode)
|
||
attn_outs = torch.vacc.fuse_atten_qwen2(
|
||
hidden_states,
|
||
residual,
|
||
self.fused_params['input_layernorm_weight'],
|
||
self.fused_params['qkv_proj_weight'],
|
||
self.fused_params['qkv_proj_weight_scale'],
|
||
qkv_proj_bias,
|
||
qkv_proj_qzeros, #self.fused_params['qkv_proj_qzeros'],
|
||
sin_cache,
|
||
cos_cache,
|
||
attn_metadata.slot_mapping,
|
||
kv_cache,
|
||
attn_metadata.block_tables,
|
||
env_blk_grp_size,
|
||
self.fused_params['o_proj_weight'],
|
||
self.fused_params['o_proj_weight_scale'],
|
||
o_proj_bias, #self.fused_params['o_proj_bias'],
|
||
o_proj_qzeros, #self.fused_params['o_proj_qzeros'],
|
||
attn_metadata.seq_lens,
|
||
self.scaling,
|
||
self.total_num_heads,
|
||
self.total_num_kv_heads,
|
||
is_decode,
|
||
is_decode,
|
||
reduce_result,
|
||
world_size=get_tp_group().world_size,
|
||
rank=get_tp_group().rank_in_group,
|
||
group_id=get_tp_group().group_id,
|
||
dev_info=get_tp_group().rank_device_infos)
|
||
# print("finished fused attention kernel......................................, is_decode:", is_decode)
|
||
# attn_outs1 = vacc_fused_attn_qwen2_naive(
|
||
# hidden_states=hidden_states,
|
||
# residual=residual,
|
||
# hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
|
||
# ###ignore the group size, it is not used in the fused kernel
|
||
# # qkv_proj_group_size=get_gptq_group_size(self.qkv_proj),
|
||
# # o_proj_group_size=get_gptq_group_size(self.o_proj),
|
||
# ###
|
||
# qkv_proj_weight=self.fused_params['qkv_proj_weight'],
|
||
# qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
|
||
# qkv_proj_bias=self.fused_params['qkv_proj_bias'],
|
||
# qkv_proj_qzeros=self.fused_params['qkv_proj_qzeros'],
|
||
# sin_cache=sin_cache,
|
||
# cos_cache=cos_cache,
|
||
# slot_mapping=attn_metadata.slot_mapping,
|
||
# kv_cache=kv_cache,
|
||
# block_tables=attn_metadata.block_tables,
|
||
# block_group_size=env_blk_grp_size,
|
||
# o_proj_weight=self.fused_params['o_proj_weight'],
|
||
# o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
|
||
# o_proj_bias=self.fused_params['o_proj_bias'],
|
||
# o_proj_qzeros=self.fused_params['o_proj_qzeros'],
|
||
# seq_lens=attn_metadata.seq_lens,
|
||
# sm_scale=self.scaling,
|
||
# num_attention_heads=self.total_num_heads,
|
||
# num_key_value_heads=self.total_num_kv_heads,
|
||
# # head_dim=self.head_dim,
|
||
# is_decode=is_decode,
|
||
# flash_attention=is_decode,
|
||
# reduce_result=reduce_result,
|
||
# world_size=get_tp_group().world_size,
|
||
# rank=get_tp_group().rank_in_group,
|
||
# group_id=get_tp_group().group_id,
|
||
# dev_info=get_tp_group().rank_device_infos
|
||
# )
|
||
|
||
|
||
# if residual is None:
|
||
# has_nan(attn_outs1, attn_outs)
|
||
# ret = torch.isnan(attn_outs).any()
|
||
# if torch.isinf(attn_outs).any() or torch.isnan(attn_outs).any():
|
||
# dump_qwen2_attention_params(hidden_states=hidden_states,
|
||
# # output=output_,
|
||
# residual=residual,
|
||
# fused_params=self.fused_params,
|
||
# attn_metadata=attn_metadata,
|
||
# is_decode=is_decode,
|
||
# sin_cache=sin_cache,
|
||
# cos_cache=cos_cache,
|
||
# kv_cache=kv_cache,
|
||
# env_blk_grp_size=env_blk_grp_size,
|
||
# scaling=self.scaling,
|
||
# total_num_heads=self.total_num_heads,
|
||
# total_num_kv_heads=self.total_num_kv_heads,
|
||
# world_size=get_tp_group().world_size,
|
||
# rank=get_tp_group().rank_in_group,
|
||
# group_id=get_tp_group().group_id,
|
||
# dev_info=get_tp_group().rank_device_infos)
|
||
# cos_sim = torch.cosine_similarity(attn_outs1.cpu().reshape([-1]).float(), attn_outs.cpu().reshape([-1]).float(), dim=-1)
|
||
# print("cos_sim witout residual:",cos_sim )
|
||
# else :
|
||
# has_nan(attn_outs1[0], attn_outs[0])
|
||
# ret = torch.isnan(attn_outs[0]).any()
|
||
# if ret:
|
||
# dump_qwen2_attention_params(hidden_states=hidden_states,
|
||
# # output=output_,
|
||
# residual=residual,
|
||
# fused_params=self.fused_params,
|
||
# attn_metadata=attn_metadata,
|
||
# is_decode=is_decode,
|
||
# sin_cache=sin_cache,
|
||
# cos_cache=cos_cache,
|
||
# kv_cache=kv_cache,
|
||
# env_blk_grp_size=env_blk_grp_size,
|
||
# scaling=self.scaling,
|
||
# total_num_heads=self.total_num_heads,
|
||
# total_num_kv_heads=self.total_num_kv_heads,
|
||
# world_size=get_tp_group().world_size,
|
||
# rank=get_tp_group().rank_in_group,
|
||
# group_id=get_tp_group().group_id,
|
||
# dev_info=get_tp_group().rank_device_infos)
|
||
# cos_sim = torch.cosine_similarity(attn_outs1[0].cpu().reshape([-1]).float(), attn_outs[0].cpu().reshape([-1]).float(), dim=-1)
|
||
# print("cos_sim residual:",cos_sim )
|
||
else:
|
||
##########################################################################
|
||
# dump_qwen2_attention_params(hidden_states=hidden_states,
|
||
# # output=output_,
|
||
# residual=residual,
|
||
# fused_params=self.fused_params,
|
||
# attn_metadata=attn_metadata,
|
||
# is_decode=is_decode,
|
||
# sin_cache=sin_cache,
|
||
# cos_cache=cos_cache,
|
||
# kv_cache=kv_cache,
|
||
# env_blk_grp_size=env_blk_grp_size,
|
||
# scaling=self.scaling,
|
||
# total_num_heads=self.total_num_heads,
|
||
# total_num_kv_heads=self.total_num_kv_heads,
|
||
# world_size=get_tp_group().world_size,
|
||
# rank=get_tp_group().rank_in_group,
|
||
# group_id=get_tp_group().group_id,
|
||
# dev_info=get_tp_group().rank_device_infos)
|
||
|
||
attn_outs = vacc_fused_attn_qwen2_naive(
|
||
hidden_states=hidden_states,
|
||
residual=residual,
|
||
hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
|
||
###ignore the group size, it is not used in the fused kernel
|
||
# qkv_proj_group_size=get_gptq_group_size(self.qkv_proj),
|
||
# o_proj_group_size=get_gptq_group_size(self.o_proj),
|
||
###
|
||
qkv_proj_weight=self.fused_params['qkv_proj_weight'],
|
||
qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
|
||
qkv_proj_bias=self.fused_params['qkv_proj_bias'],
|
||
qkv_proj_qzeros=None,
|
||
sin_cache=sin_cache,
|
||
cos_cache=cos_cache,
|
||
slot_mapping=attn_metadata.slot_mapping,
|
||
kv_cache=kv_cache,
|
||
block_tables=attn_metadata.block_tables,
|
||
block_group_size=env_blk_grp_size,
|
||
o_proj_weight=self.fused_params['o_proj_weight'],
|
||
o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
|
||
o_proj_bias=None,
|
||
o_proj_qzeros=None,
|
||
seq_lens=attn_metadata.seq_lens,
|
||
sm_scale=self.scaling,
|
||
num_attention_heads=self.total_num_heads,
|
||
num_key_value_heads=self.total_num_kv_heads,
|
||
# head_dim=self.head_dim,
|
||
is_decode=is_decode,
|
||
flash_attention=is_decode,
|
||
reduce_result=reduce_result,
|
||
world_size=get_tp_group().world_size,
|
||
rank=get_tp_group().rank_in_group,
|
||
group_id=get_tp_group().group_id,
|
||
dev_info=get_tp_group().rank_device_infos
|
||
)
|
||
|
||
|
||
# if is_decode:
|
||
# assert(1<0)
|
||
# if residual is None:
|
||
# output_ = attn_outs
|
||
# else:
|
||
# output_ = attn_outs[0]
|
||
# print("starting tensor_model_parallel_all_reduce......................................, is_decode:", is_decode)
|
||
if residual is None:
|
||
attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
|
||
else:
|
||
res_out = attn_outs[1]
|
||
attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
|
||
# print("finished tensor_model_parallel_all_reduce......................................, is_decode:", is_decode)
|
||
return attn_out, res_out
|
||
elif use_unquan:
|
||
# Determine if we should reduce the result inside the kernel
|
||
# print(f"Qwen2Attention forward, is_decode: {is_decode}, kv_cache shape: {kv_cache.shape}, attn_metadata.seq_lens: {attn_metadata.seq_lens}")
|
||
kv_cache = kv_cache.view(2, -1, 16, self.total_num_kv_heads // get_tp_group().world_size, self.head_dim)
|
||
|
||
reduce_result = is_decode
|
||
|
||
if cos_cache is None or sin_cache is None:
|
||
cos_cache, sin_cache = get_cos_sin_cache(self.rotary_emb, attn_metadata, positions, is_decode)
|
||
|
||
if residual is None:
|
||
res_out = hidden_states
|
||
|
||
if USE_FUSED_QWEN_ATTENTION:
|
||
qkv_proj_qzeros = torch.Tensor()
|
||
o_proj_bias = torch.Tensor()
|
||
o_proj_qzeros = torch.Tensor()
|
||
|
||
# Call the naive fused kernel
|
||
####################################call fusion kernel######################################
|
||
# print("starting fused attention kernel......................................, is_decode:", is_decode)
|
||
# print('hidden_states', hidden_states.shape)
|
||
# print('attn_metadata.slot_mapping,', attn_metadata.slot_mapping,)
|
||
# print('attn_metadata.block_tables,', attn_metadata.block_tables,)
|
||
# print('attn_metadata.seq_lens,', attn_metadata.seq_lens,)
|
||
attn_outs = torch.vacc.fuse_atten_qwen2(
|
||
hidden_states,
|
||
residual,
|
||
self.fused_params['input_layernorm_weight'],
|
||
self.fused_params['qkv_proj_weight'],
|
||
self.fused_params['qkv_proj_weight_scale'],
|
||
# qkv_proj_bias,
|
||
self.fused_params['qkv_proj_bias'],
|
||
qkv_proj_qzeros, #self.fused_params['qkv_proj_qzeros'],
|
||
sin_cache,
|
||
cos_cache,
|
||
attn_metadata.slot_mapping,
|
||
kv_cache,
|
||
attn_metadata.block_tables,
|
||
env_blk_grp_size,
|
||
self.fused_params['o_proj_weight'],
|
||
self.fused_params['o_proj_weight_scale'],
|
||
o_proj_bias, #self.fused_params['o_proj_bias'],
|
||
o_proj_qzeros, #self.fused_params['o_proj_qzeros'],
|
||
attn_metadata.seq_lens,
|
||
self.scaling,
|
||
self.total_num_heads,
|
||
self.total_num_kv_heads,
|
||
is_decode,
|
||
is_decode,
|
||
reduce_result,
|
||
world_size=get_tp_group().world_size,
|
||
rank=get_tp_group().rank_in_group,
|
||
group_id=get_tp_group().group_id,
|
||
dev_info=get_tp_group().rank_device_infos)
|
||
else:
|
||
attn_outs = vacc_fused_attn_qwen2_naive(
|
||
hidden_states=hidden_states,
|
||
residual=residual,
|
||
hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
|
||
###ignore the group size, it is not used in the fused kernel
|
||
# qkv_proj_group_size=get_gptq_group_size(self.qkv_proj),
|
||
# o_proj_group_size=get_gptq_group_size(self.o_proj),
|
||
###
|
||
qkv_proj_weight=self.fused_params['qkv_proj_weight'],
|
||
qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
|
||
qkv_proj_bias=self.fused_params['qkv_proj_bias'],
|
||
qkv_proj_qzeros=None,
|
||
sin_cache=sin_cache,
|
||
cos_cache=cos_cache,
|
||
slot_mapping=attn_metadata.slot_mapping,
|
||
kv_cache=kv_cache,
|
||
block_tables=attn_metadata.block_tables,
|
||
block_group_size=env_blk_grp_size,
|
||
o_proj_weight=self.fused_params['o_proj_weight'],
|
||
o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
|
||
o_proj_bias=None,
|
||
o_proj_qzeros=None,
|
||
seq_lens=attn_metadata.seq_lens,
|
||
sm_scale=self.scaling,
|
||
num_attention_heads=self.total_num_heads,
|
||
num_key_value_heads=self.total_num_kv_heads,
|
||
# head_dim=self.head_dim,
|
||
is_decode=is_decode,
|
||
flash_attention=is_decode,
|
||
reduce_result=reduce_result,
|
||
world_size=get_tp_group().world_size,
|
||
rank=get_tp_group().rank_in_group,
|
||
group_id=get_tp_group().group_id,
|
||
dev_info=get_tp_group().rank_device_infos
|
||
)
|
||
|
||
# if is_decode:
|
||
# assert(1<0)
|
||
# if residual is None:
|
||
# output_ = attn_outs
|
||
# else:
|
||
# output_ = attn_outs[0]
|
||
# print("starting tensor_model_parallel_all_reduce......................................, is_decode:", is_decode)
|
||
if residual is None:
|
||
attn_out = tensor_model_parallel_all_reduce(attn_outs) if not reduce_result else attn_outs
|
||
# print(f"--------------------------attn_outs :{attn_out}")
|
||
# print(f"--------------------------attn_outs shape :{attn_out.shape}")
|
||
# print(f"--------------------------res_out :{res_out}")
|
||
else:
|
||
res_out = attn_outs[1]
|
||
attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) if not reduce_result else attn_outs[0]
|
||
# print(f"--------------------------attn_outs :{attn_out}")
|
||
# print(f"--------------------------attn_outs shape :{attn_out.shape}")
|
||
# print(f"--------------------------res_out :{res_out}")
|
||
# print("finished tensor_model_parallel_all_reduce......................................, is_decode:", is_decode)
|
||
|
||
return attn_out, res_out
|
||
else:
|
||
# Fallback to original implementation
|
||
# save_tensor_pth("hidden_states", hidden_states)
|
||
qkv, _ = self.qkv_proj(hidden_states)
|
||
# save_tensor_pth("qkv", qkv)
|
||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||
# save_tensor_pth("q", q)
|
||
# save_tensor_pth("k", k)
|
||
# save_tensor_pth("v", v)
|
||
q, k = self.rotary_emb(positions, q, k)
|
||
# save_tensor_pth("q_rot", q)
|
||
# save_tensor_pth("k_rot", k)
|
||
attn_output = self.attn(q, k, v)
|
||
# save_tensor_pth("attn_output", attn_output)
|
||
output, _ = self.o_proj(attn_output)
|
||
# save_tensor_pth("output", output)
|
||
return output
|
||
|
||
|
||
class Qwen2DecoderLayer(nn.Module):
|
||
|
||
def forward(
|
||
self,
|
||
positions: torch.Tensor,
|
||
hidden_states: torch.Tensor,
|
||
residual: Optional[torch.Tensor],
|
||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
# forward_context: ForwardContext = get_forward_context()
|
||
# attn_metadata = forward_context.attn_metadata
|
||
# kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
|
||
# is_decode = attn_metadata.prefill_metadata is None
|
||
use_gptq = hasattr(self.self_attn.qkv_proj, 'quant_method') and isinstance(self.self_attn.qkv_proj.quant_method, GPTQLinearMethod)
|
||
use_unquan = hasattr(self.self_attn.qkv_proj, 'quant_method') and isinstance(self.self_attn.qkv_proj.quant_method, UnquantizedLinearMethod)
|
||
# if USE_FUSED_QWEN_ATTENTION and not is_decode:
|
||
if use_gptq or use_unquan:
|
||
# Self Attention with fused pre-layernorm
|
||
# Pass the layernorm weight to the attention layer for the fused kernel
|
||
self.self_attn.input_layernorm_weight = self.input_layernorm.weight
|
||
self.self_attn.fused_params = {}
|
||
self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
|
||
set_fused_params(self.self_attn.fused_params, self.self_attn.qkv_proj.quant_method, self.self_attn.qkv_proj, 'qkv_proj')
|
||
set_fused_params(self.self_attn.fused_params, self.self_attn.o_proj.quant_method, self.self_attn.o_proj, 'o_proj')
|
||
hidden_states, residual = self.self_attn(
|
||
positions=positions,
|
||
hidden_states=hidden_states,
|
||
residual=residual,
|
||
)
|
||
else:
|
||
# Original non-fused path
|
||
if residual is None:
|
||
residual = hidden_states
|
||
hidden_states = self.input_layernorm(hidden_states)
|
||
else:
|
||
hidden_states, residual = self.input_layernorm(
|
||
hidden_states, residual)
|
||
hidden_states = self.self_attn(
|
||
positions=positions,
|
||
hidden_states=hidden_states,
|
||
)
|
||
|
||
# Fully Connected (MLP) part remains the same
|
||
hidden_states, residual = self.post_attention_layernorm(
|
||
hidden_states, residual)
|
||
# save_tensor_pth("post_attention_layernorm_hidden_states", hidden_states)
|
||
# save_tensor_pth("post_attention_layernorm_residual", residual)
|
||
|
||
hidden_states = self.mlp(hidden_states)
|
||
return hidden_states, residual |