# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights, with VACC modifications."""
from typing import Iterable, Optional, Set, Tuple, Union, List, Any, Dict
import math
import os

import torch
import numpy as np
from torch import nn
from transformers import Qwen2Config

from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce, get_tp_group)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.forward_context import ForwardContext, get_forward_context

# Import original classes to be modified
from vllm.model_executor.models.qwen2 import (Qwen2Attention as Qwen2AttentionOrig,
                                              Qwen2DecoderLayer as Qwen2DecoderLayerOrig,
                                              Qwen2MLP as Qwen2MLPOrig)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm_vacc.vllm.model_executor.layers.layernorm import RMSNorm_forward_vacc
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding, apply_interleaved_rope
from ..ops.mrope_op import get_sin_cos_mrope
import copy

logger = init_logger(__name__)

# --- VACC specific additions ---
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT, USE_FUSED_QWEN_ATTENTION
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size


def save_tensor_pth(name, tensor):
    # Tensor dumping is disabled by default; drop this early return to enable it.
    return
    import time
    if tensor is None:
        return
    process_id = os.getpid()
    timestamp = time.strftime("%Y-%m-%d_%H", time.localtime())
    pathdir = f'./dump_qwen_pth/{timestamp}/process_{process_id}'
    if not os.path.exists(pathdir):
        os.makedirs(pathdir)
    counter = 0
    filename = f"{name}_{counter}.pth"
    while os.path.exists(os.path.join(pathdir, filename)):
        counter += 1
        filename = f"{name}_{counter}.pth"
    file_path = os.path.join(pathdir, filename)
    if isinstance(tensor, torch.Tensor):
        print(f'save {filename} {tensor.shape}')
        torch.save(tensor.cpu(), file_path)
    elif 'sin_cache' in name or 'cos_cache' in name:
        tensor_list = [i.cpu() for i in tensor]
        torch.save(tensor_list, file_path)
    else:
        torch.save(tensor, file_path)
    return


# --- GPTQ Dequantization Logic (from gptq.py) ---
def int32_to_int4(s0, axis=-2):
    # Flatten to shape [1, n] first. Each int32 unpacks into 8 int4 values,
    # giving shape [8, n]:
    #   bits 31-28 of x32 -> x4[7]
    #   bits 27-24 of x32 -> x4[6]
    #   ...
    #   bits  3-0  of x32 -> x4[0]
    # Mapping a 4-bit nibble to a real value is NOT two's complement; the zero
    # point is 8, so nibble 15 -> 7 and nibble 6 -> -2.
    # Example: 0x6ACB372B (stored little-endian as 2B 37 CB 6A) decodes, low
    # nibble first, to the int4 values 3, -6, -1, -5, 3, 4, 2, -2.
    s = s0.view(torch.uint32)
    all_nibbles = []
    for i in range(8):
        x = 15 << (i * 4)
        # torch.bitwise_and is unavailable for this dtype, so go through numpy.
        s2 = torch.from_numpy(np.bitwise_and(x, s.numpy()))
        s3 = s2 / (2 ** (i * 4))
        s4 = s3.to(torch.int32)
        # Two's-complement decoding (s4[s4 > 7] -= 16) gives wrong results;
        # subtracting the zero point 8 is correct, yielding the range -8..7.
        s4 = s4 - 8
        all_nibbles.append(s4.reshape(1, *s4.shape))
    all_nibbles = torch.concatenate(all_nibbles, 0)
    if axis == -2 or axis == 0:
        # [8, K//8, N] => [K//8, 8, N] => [K, N]
        all_nibbles = all_nibbles.transpose(-2, 0).reshape(
            -1, all_nibbles.shape[-1]).contiguous()
    else:
        # [8, N, K//8] => [N, K//8, 8] => [N, K]
        all_nibbles = all_nibbles.permute(1, 2, 0).reshape(
            all_nibbles.shape[-2], -1).contiguous()
    return all_nibbles


def dequant_weight(qw, scales, group_size=128):
    N = qw.shape[1]
    int4_to_int32_axis = -2
    if TRANSPOSE_GPTQ_WEIGHT:
        N = qw.shape[0]
        int4_to_int32_axis = -1
    # int32 => 8 int4 values => fp16
    qweight = int32_to_int4(qw, int4_to_int32_axis).to(torch.float16)
    if TRANSPOSE_GPTQ_WEIGHT:
        scales = scales.T.contiguous()
        qweight = qweight.T.contiguous()
    # Broadcast the scales along group_size: every group_size consecutive
    # weights share one scale.
    scales = torch.concatenate([scales] * group_size, 1).reshape(-1, N)
    return qweight * scales  # dequantized weight


def apply_gptq_linear(
    input_tensor: torch.Tensor,
    layer: Union[QKVParallelLinear, RowParallelLinear],
) -> torch.Tensor:
    """Applies a GPTQ-quantized linear layer by dequantizing weights on-the-fly."""
    out_shape = input_tensor.shape[:-1] + (
        layer.qweight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], )  # (M, N)
    reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1])

    # This assumes the linear_method is attached and is a GPTQ derivative.
    linear_method = layer.quant_method
    quant_config = linear_method.quant_config
    # scale_k is a VACC-specific parameter for scale broadcasting.
    scale_k = getattr(linear_method, "scale_k", 1)

    # Dequantize the weight on CPU, then move it back to the layer's device.
    weight = dequant_weight(
        layer.qweight.cpu(), layer.scales.cpu(),
        quant_config.group_size // scale_k).to(layer.qweight.device)

    # Perform the GEMM.
    output = torch.matmul(reshaped_x, weight)

    # Add the bias if it exists.
    if hasattr(layer, 'bias') and layer.bias is not None:
        output = output + layer.bias
    return output.reshape(out_shape)


def get_gptq_group_size(layer: Union[QKVParallelLinear, RowParallelLinear]):
    # This assumes the linear_method is attached and is a GPTQ derivative.
    linear_method = layer.quant_method
    quant_config = linear_method.quant_config
    # scale_k is a VACC-specific parameter for scale broadcasting.
    scale_k = getattr(linear_method, "scale_k", 1)
    return quant_config.group_size // scale_k


def apply_gptq_linear_(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    qzeros: Optional[torch.Tensor],
    group_size_h: int,
    group_size_w: int,
) -> torch.Tensor:
    """Applies a GPTQ-quantized linear layer via the VACC W4A8 block-int4 GEMM."""
    out_shape = input_tensor.shape[:-1] + (
        weight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], )  # (M, N)
    reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1])

    # CPU reference: dequant_weight(...) followed by torch.matmul.
    output = torch.vacc.w4a8_block_int4_matmul(
        reshaped_x,
        weight.transpose(-1, -2),
        weight_scale.transpose(-1, -2),
        [group_size_h, group_size_w],
    )

    if bias is not None:
        output = output + bias
    return output.reshape(out_shape)


def post_layernorm(
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    variance_size_override: Optional[int] = None,
    variance_epsilon: float = 1e-6,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward()."""
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)

    hidden_size = x.shape[-1]
    if variance_size_override is None:
        x_var = x
    else:
        if hidden_size < variance_size_override:
            raise ValueError("Expected hidden_size to be at least "
                             f"{variance_size_override}, but found: {hidden_size}")
        x_var = x[:, :, :variance_size_override]

    variance = x_var.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + variance_epsilon)
    x = x.to(orig_dtype)
    if weight is not None:
        x = x * weight
    if residual is None:
        return x
    return x, residual
# --- VACC Fused Kernel for Qwen2 ---
def vacc_fused_attn_qwen2_naive(
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    hidden_states_norm_weight: torch.Tensor,
    qkv_proj_weight: torch.Tensor,
    qkv_proj_weight_scale: torch.Tensor,
    qkv_proj_bias: Optional[torch.Tensor],
    qkv_proj_qzeros: Optional[torch.Tensor],
    sin_cache: List[torch.Tensor],
    cos_cache: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    block_tables: torch.Tensor,
    block_group_size: int,
    o_proj_weight: torch.Tensor,
    o_proj_weight_scale: torch.Tensor,
    o_proj_bias: Optional[torch.Tensor],
    o_proj_qzeros: Optional[torch.Tensor],
    seq_lens: List[int],
    sm_scale: float,
    num_attention_heads: int,
    num_key_value_heads: int,
    flash_attention: bool,
    is_decode: bool,
    reduce_result: bool,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: List[int] | Tuple[int],
    block_size: int = 16,
):
    # Derive the per-projection group sizes from the weight/scale shapes
    # (each int32 packs 8 int4 values).
    qkv_proj_group_size_h = qkv_proj_weight.shape[-1] * 8 // qkv_proj_weight_scale.shape[-1]
    qkv_proj_group_size_w = qkv_proj_weight.shape[-2] * 8 // qkv_proj_weight_scale.shape[-2]
    o_proj_group_size_h = o_proj_weight.shape[-1] * 8 // o_proj_weight_scale.shape[-1]
    o_proj_group_size_w = o_proj_weight.shape[-2] * 8 // o_proj_weight_scale.shape[-2]

    if residual is not None:
        hidden_states = hidden_states + residual
        residual_out = hidden_states
    save_tensor_pth("hidden_states", hidden_states)

    # 1. Fused RMSNorm
    hidden_states_norm = torch.vacc.rms_norm(
        hidden_states.unsqueeze(0), hidden_states_norm_weight, 1e-6).squeeze(0)

    # NOTE: for Qwen3 and Qwen2.5, head_dim is always 128.
    head_dim = 128

    # 2. QKV projection using on-the-fly dequantization
    save_tensor_pth("hidden_states_norm", hidden_states_norm)
    qkv = apply_gptq_linear_(hidden_states_norm, qkv_proj_weight,
                             qkv_proj_weight_scale, qkv_proj_bias,
                             qkv_proj_qzeros, qkv_proj_group_size_h,
                             qkv_proj_group_size_w)
    save_tensor_pth("qkv", qkv)

    # Split Q, K, V
    num_q_heads = num_attention_heads // world_size
    num_kv_heads = max(1, num_key_value_heads // world_size)
    q_size = num_q_heads * head_dim
    kv_size = num_kv_heads * head_dim
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    save_tensor_pth("q", q)
    save_tensor_pth("k", k)
    save_tensor_pth("v", v)
    q = q.view(-1, num_q_heads, head_dim)
    k = k.view(-1, num_kv_heads, head_dim)
    v = v.view(-1, num_kv_heads, head_dim)
    save_tensor_pth("q_reshaped", q)
    save_tensor_pth("k_reshaped", k)
    save_tensor_pth("v_reshaped", v)

    # 3. RoPE, KV caching, and attention
    start = 0
    attn_outs = []
    if is_decode:
        # Convert the block tables to block-group indices.
        block_per_group = block_group_size // block_size
        block_tables_grouped = (block_tables // block_per_group).to(torch.int32)

    num_blocks = kv_cache.shape[1]
    key_cache_split = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_dim)
    value_cache_split = kv_cache[1].view(num_blocks, -1, num_kv_heads, head_dim)

    # Loop over the sequences in the batch.
    for i, seq_len in enumerate(seq_lens):
        if not is_decode:  # Prefill
            end = start + seq_len
        else:  # Decode
            end = start + 1
        cos = cos_cache[i].unsqueeze(-2)
        sin = sin_cache[i].unsqueeze(-2)
        q_i, k_i, v_i = q[start:end], k[start:end], v[start:end]

        # Apply rotary positional embedding.
        q_rot, k_rot = torch.vacc.RotaryPosEmbedding(q_i, k_i, cos, sin, 0, "neox")
        save_tensor_pth("q_rot", q_rot)
        save_tensor_pth("k_rot", k_rot)

        # Reshape and cache K/V.
        torch.vacc.reshape_and_cache_attention(k_rot, key_cache_split,
                                               slot_mapping[start:end, ...])
        torch.vacc.reshape_and_cache_attention(v_i, value_cache_split,
                                               slot_mapping[start:end, ...])

        # Attention calculation.
        if not is_decode:
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q_rot,
                key=k_rot,
                value=v_i.contiguous(),
                attn_mask=None,
                dropout_p=0.0,
                is_causal=True,
                is_train=False,
                recompute=False,
                flash_attention=False,
                sm_scale=sm_scale)
        else:
            # For decode, reconstruct the past K/V from the cache.
            key_cache_grouped = key_cache_split.view(
                -1, block_group_size, num_kv_heads, head_dim)
            value_cache_grouped = value_cache_split.view(
                -1, block_group_size, num_kv_heads, head_dim)
            k_slices = key_cache_grouped[block_tables_grouped[i], :, :, :]
            k_past = torch.cat([k_slices[j].unsqueeze(0)
                                for j in range(len(block_tables_grouped[i]))], dim=0)
            k_past = k_past.reshape(-1, num_kv_heads, head_dim)[:seq_len]
            v_slices = value_cache_grouped[block_tables_grouped[i], :, :, :]
            v_past = torch.cat([v_slices[j].unsqueeze(0)
                                for j in range(len(block_tables_grouped[i]))], dim=0)
            v_past = v_past.reshape(-1, num_kv_heads, head_dim)[:seq_len]
            attn_out = torch.vacc.scaled_dot_product_attention(
                query=q_rot.contiguous(),
                key=k_past.contiguous(),
                value=v_past.contiguous(),
                attn_mask=None,
                dropout_p=0,
                is_causal=False,
                is_train=False,
                recompute=False,
                flash_attention=False,
                sm_scale=sm_scale)
        attn_outs.append(attn_out)
        start = end

    attn_output_cat = torch.cat(attn_outs, dim=0)
    save_tensor_pth("attn_output", attn_output_cat)

    # 4. Output projection
    attn_output_reshaped = attn_output_cat.view(hidden_states.shape[0], -1)
    output = apply_gptq_linear_(attn_output_reshaped, o_proj_weight,
                                o_proj_weight_scale, o_proj_bias, o_proj_qzeros,
                                o_proj_group_size_h, o_proj_group_size_w)

    # 5. Optional all-reduce
    if reduce_result:
        output = tensor_model_parallel_all_reduce(output)
    if residual is not None:
        return output, residual_out
    return output
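# A minimal sanity check for int32_to_int4 above (illustrative only; it is not
# called anywhere in the model path, and it assumes the runtime supports
# torch's uint32/numpy interop, which int32_to_int4 itself already requires).
# Eight known nibbles are packed into one int32 and the decode is expected to
# recover nibble - 8 for each, low nibble first.
def _demo_int32_to_int4():
    nibbles = [11, 2, 15, 0, 8, 7, 3, 12]
    packed = 0
    for i, nib in enumerate(nibbles):
        packed |= nib << (4 * i)
    qw = torch.from_numpy(np.array([[packed]], dtype=np.uint32).view(np.int32))
    decoded = int32_to_int4(qw)  # default axis: one [8, 1] column
    expected = torch.tensor([[nib - 8] for nib in nibbles], dtype=torch.int32)
    assert torch.equal(decoded, expected)


# A small runnable sketch of post_layernorm above (shapes are illustrative).
# It shows the fused "add residual, then RMS-normalize" contract: the returned
# residual is the pre-norm sum cast back to the input dtype, and the first
# output is the normalized activation.
def _demo_post_layernorm():
    x = torch.randn(2, 5, 8, dtype=torch.float16)
    res = torch.randn(2, 5, 8, dtype=torch.float16)
    w = torch.ones(8, dtype=torch.float16)
    out, new_res = post_layernorm(x, residual=res, weight=w)
    assert out.shape == x.shape and new_res.dtype == x.dtype
    # new_res equals (x + res) computed in fp32 and cast back to fp16.
    assert torch.allclose(new_res, (x.float() + res.float()).half())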
# --- Rewriting forward methods for Qwen2Attention and Qwen2DecoderLayer ---

# Unify the parameter names coming from the different quantization methods.
def set_fused_params(fused_params: Dict[str, Any],
                     quant_method: QuantizeMethodBase, layer: nn.Module,
                     name: str):
    if isinstance(quant_method, UnquantizedLinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = torch.Tensor()
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None
    elif isinstance(quant_method, Fp8LinearMethod):
        fused_params[name + '_weight'] = layer.weight
        fused_params[name + '_weight_scale'] = layer.weight_scale_inv
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, GPTQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    elif isinstance(quant_method, AWQLinearMethod):
        fused_params[name + '_weight'] = layer.qweight
        fused_params[name + '_weight_scale'] = layer.scales
        fused_params[name + '_bias'] = None if not hasattr(layer, 'bias') else layer.bias
        fused_params[name + '_qzeros'] = None if not hasattr(layer, 'qzeros') else layer.qzeros
    else:
        raise ValueError(f"Unsupported quant_method: {quant_method}")
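# A minimal sketch of how set_fused_params normalizes parameter names
# (illustrative only: a plain nn.Linear stands in for a vLLM parallel linear
# layer, and UnquantizedLinearMethod is assumed to construct without
# arguments).
def _demo_set_fused_params():
    layer = nn.Linear(8, 8, bias=True)
    fused_params = {}
    set_fused_params(fused_params, UnquantizedLinearMethod(), layer, "o_proj")
    assert fused_params["o_proj_weight"] is layer.weight
    assert fused_params["o_proj_bias"] is layer.bias
    assert fused_params["o_proj_qzeros"] is None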
# Global counter used to give each dump batch a unique ID.
# NOTE: This counter is process-specific. If you run multiple processes,
# each will have its own counter starting from 0. The directory structure
# already includes the process_id, so there won't be conflicts.
DUMP_COUNTER = 0


def dump_qwen2_attention_params(hidden_states, residual, fused_params,
                                attn_metadata, is_decode, sin_cache, cos_cache,
                                kv_cache, env_blk_grp_size, scaling,
                                total_num_heads, total_num_kv_heads,
                                world_size, rank, group_id, dev_info):
    """Collect all parameters to dump and save them as a single batch."""
    # 1. Collect every tensor/value to save into one dict. Keys are the file
    # names (without extension), values are the data themselves. Crucially,
    # None values are kept in the dict and handled by the save function
    # instead of being skipped here.
    params_to_dump = {
        'hidden_states': hidden_states,
        'residual': residual,
        'hidden_states_norm_weight': fused_params.get('input_layernorm_weight'),
        'qkv_proj_weight': fused_params.get('qkv_proj_weight'),
        'qkv_proj_weight_scale': fused_params.get('qkv_proj_weight_scale'),
        'qkv_proj_bias': fused_params.get('qkv_proj_bias'),
        'qkv_proj_qzeros': fused_params.get('qkv_proj_qzeros'),
        'o_proj_weight': fused_params.get('o_proj_weight'),
        'o_proj_weight_scale': fused_params.get('o_proj_weight_scale'),
        'o_proj_bias': fused_params.get('o_proj_bias'),
        'o_proj_qzeros': fused_params.get('o_proj_qzeros'),
        'sin_cache': sin_cache,
        'cos_cache': cos_cache,
        'kv_cache': kv_cache,
        'slot_mapping': attn_metadata.slot_mapping if attn_metadata else None,
        'block_tables': attn_metadata.block_tables if attn_metadata else None,
        'seq_lens': attn_metadata.seq_lens if attn_metadata else None,
        'block_group_size': env_blk_grp_size,
        'sm_scale': scaling,
        'num_attention_heads': total_num_heads,
        'num_key_value_heads': total_num_kv_heads,
        'world_size': world_size,
        'rank': rank,
        'group_id': group_id,
        'dev_info': dev_info,
    }
    # 2. Save the whole batch in a single call.
    save_dump_batch(params_to_dump)


def save_dump_batch(tensors_dict):
    """Save every tensor/value in the dict to a unique per-call subdirectory."""
    global DUMP_COUNTER
    import time

    # Base directory layout: ./dump_qwen2_prefill_pth/{timestamp}/process_{pid}
    process_id = os.getpid()
    timestamp = time.strftime("%Y-%m-%d_%H", time.localtime())
    base_dir = f'./dump_qwen2_prefill_pth/{timestamp}/process_{process_id}'

    # Create a unique subdirectory for this dump batch and advance the global
    # counter for the next call.
    batch_dir = os.path.join(base_dir, f'dump_{DUMP_COUNTER}')
    DUMP_COUNTER += 1

    os.makedirs(batch_dir, exist_ok=True)
    print(f"--- Dumping batch to: {batch_dir} ---")

    # Save every non-None entry.
    for name, data in tensors_dict.items():
        if data is None:
            continue
        file_path = os.path.join(batch_dir, f"{name}.pth")
        try:
            if isinstance(data, torch.Tensor):
                print(f'  Saving tensor: {name}.pth, shape: {data.shape}, dtype: {data.dtype}')
                torch.save(data.cpu(), file_path)
            elif name in ['sin_cache', 'cos_cache'] and isinstance(data, (list, tuple)):
                # Special handling for the sin/cos caches (lists of tensors).
                tensor_list = [i.cpu() for i in data]
                print(f'  Saving tensor list: {name}.pth, count: {len(tensor_list)}')
                torch.save(tensor_list, file_path)
            else:
                # Save other picklable Python objects (numbers, strings, ...).
                print(f'  Saving metadata: {name}.pth, value: {data}')
                torch.save(data, file_path)
        except Exception as e:
            print(f"  [ERROR] Failed to save {name}.pth: {e}")
    print(f"--- Finished dumping batch {DUMP_COUNTER - 1} ---")


def has_nan(t1, t2):
    has_nan_1 = torch.isnan(t1).any()
    has_nan_2 = torch.isnan(t2).any()
    if has_nan_1 or has_nan_2:
        print("Computation aborted: NaN values were detected.")
        if has_nan_1:
            print(" - Tensor 1 contains NaN.")
        if has_nan_2:
            print(" - Tensor 2 contains NaN.")


class MlpSplitConfig:

    def __init__(self):
        self.k_split = 0
        self.n_split = 0
        self.w1_weight_addr = 0
        self.w3_weight_addr = 0
        self.matA_addr_0 = 0
        self.matA_addr_1 = 0
        self.w1_splitk_buffer = 0
        self.w3_splitk_buffer = 0
        self.w13_multiply = 0
        self.w2_splitn_buffer_list = []
        self.w13_dequant_block_size = 0
        self.w2_dequant_block_size = 0


def get_dequant_block_size_int4(k, n, k_split, n_split):
    ret = []
    w13_dequant_block_init_size = 128
    k_per_split = k // k_split
    if k_per_split % w13_dequant_block_init_size == 0:
        ret.append(w13_dequant_block_init_size)
    else:
        while w13_dequant_block_init_size >= 16 and k_per_split % w13_dequant_block_init_size != 0:
            w13_dequant_block_init_size //= 2
        if w13_dequant_block_init_size >= 16:
            ret.append(w13_dequant_block_init_size)
        else:
            ret.append(-1)
    w2_dequant_block_init_size = 128
    n_per_split = n // n_split
    if n_per_split % w2_dequant_block_init_size == 0:
        ret.append(w2_dequant_block_init_size)
    else:
        while w2_dequant_block_init_size >= 16 and n_per_split % w2_dequant_block_init_size != 0:
            w2_dequant_block_init_size //= 2
        if w2_dequant_block_init_size >= 16:
            ret.append(w2_dequant_block_init_size)
        else:
            ret.append(-1)
    return ret


def get_split_config(total_split):
    ret = []
    max_split_n = math.ceil(math.sqrt(total_split))
    for i in range(1, max_split_n + 1):
        if total_split % i == 0:
            cur_config = [total_split // i, i]  # [k_split, n_split]
            ret.append(cur_config)
    split_config_size = len(ret)
    for i in range(split_config_size - 1, -1, -1):
        if ret[i][0] != ret[i][1]:
            ret.append([ret[i][1], ret[i][0]])
    return ret
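# A runnable sketch of the split-enumeration helpers above (values are
# illustrative). get_split_config enumerates [k_split, n_split] factor pairs,
# and get_dequant_block_size_int4 picks the largest power-of-two block size
# (starting from 128, floor 16) that divides each per-split dimension.
def _demo_split_helpers():
    assert get_split_config(4) == [[4, 1], [2, 2], [1, 4]]
    # 3584 / 2 = 1792 = 14 * 128, while 4736 / 2 = 2368 = 37 * 64.
    assert get_dequant_block_size_int4(3584, 4736, k_split=2, n_split=2) == [128, 64]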
def validate_mlp_split_config(k, n, k_split, n_split, config):
    min_seqlen = 1024

    # Divisibility checks.
    if k % k_split != 0:
        return False
    if n % (n_split * 4) != 0:
        return False
    if k % (k_split * 4) != 0:
        return False
    if (n / (n_split * 4)) % 16 != 0:
        return False
    if (k / (k_split * 4)) % 16 != 0:
        return False

    # Sizes of the weight slice and the input matrix.
    w13_split_size = (k // k_split) * (n // n_split // 4) * 2 * 2
    matA_size = min_seqlen * (k // k_split // 4) * 2 * 2

    # Check that SSRAM banks 4-7 are large enough.
    if w13_split_size + matA_size > 5 * 1024 * 1024:
        return False

    # Remaining space.
    ssram_4_to_7_size_left = 5 * 1024 * 1024 - (w13_split_size + matA_size)
    w13_result_size = min_seqlen * (n // n_split // 4) * 2

    # Check the space available for results.
    if w13_result_size * 3 > ssram_4_to_7_size_left + 5 * 1024 * 1024:
        return False

    ssram_0_to_3_size_left = 5 * 1024 * 1024
    w13_result_save_type = -1

    # Choose where the results are stored.
    if w13_result_size * 3 <= 5 * 1024 * 1024:
        w13_result_save_type = 0
        ssram_0_to_3_size_left -= w13_result_size * 3
    elif w13_result_size * 2 <= 5 * 1024 * 1024:
        if w13_result_size <= ssram_4_to_7_size_left:
            w13_result_save_type = 1
            ssram_0_to_3_size_left -= w13_result_size * 2
            ssram_4_to_7_size_left -= w13_result_size
        else:
            return False
    else:
        return False

    # Handle the n_split > 1 case.
    if n_split > 1:
        w2_split_k_buffer_total_size = min_seqlen * (k // k_split // 4) * 2 * k_split
        w2_split_k_buffer_size = min_seqlen * (k // k_split // 4) * 2
        if w2_split_k_buffer_total_size > ssram_0_to_3_size_left + ssram_4_to_7_size_left:
            return False
        available_blocks = (ssram_0_to_3_size_left // w2_split_k_buffer_size) + \
                           (ssram_4_to_7_size_left // w2_split_k_buffer_size)
        if available_blocks < k_split:
            return False

    # Fill in the config.
    config.k_split = k_split
    config.n_split = n_split
    config.w1_weight_addr = 0x54000000
    config.w3_weight_addr = config.w1_weight_addr + (k // k_split) * (n // n_split // 4) * 2
    config.matA_addr_0 = config.w3_weight_addr + (k // k_split) * (n // n_split // 4) * 2
    config.matA_addr_1 = config.matA_addr_0 + min_seqlen * (k // k_split // 4) * 2
    config.w1_splitk_buffer = 0x50000000
    config.w3_splitk_buffer = config.w1_splitk_buffer + min_seqlen * (n // n_split // 4) * 2
    if w13_result_save_type == 0:
        config.w13_multiply = config.w3_splitk_buffer + min_seqlen * (n // n_split // 4) * 2
    elif w13_result_save_type == 1:
        config.w13_multiply = config.matA_addr_1 + min_seqlen * (k // k_split // 4) * 2

    # Allocate the w2 buffers.
    if n_split > 1:
        w2_split_k_buffer_size = min_seqlen * (k // k_split // 4) * 2
        w2_split_k_buffer_in_ssram_0_to_3 = min(
            ssram_0_to_3_size_left // w2_split_k_buffer_size, k_split)
        w2_split_k_buffer_in_ssram_4_to_7 = k_split - w2_split_k_buffer_in_ssram_0_to_3
        if w13_result_save_type == 0:
            ssram_0_to_3_start_addr = config.w13_multiply + min_seqlen * (n // n_split // 4) * 2
            ssram_4_to_7_start_addr = config.matA_addr_1 + min_seqlen * (k // k_split // 4) * 2
        else:
            ssram_0_to_3_start_addr = config.w3_splitk_buffer + min_seqlen * (n // n_split // 4) * 2
            ssram_4_to_7_start_addr = config.w13_multiply + min_seqlen * (n // n_split // 4) * 2
        for i in range(w2_split_k_buffer_in_ssram_0_to_3):
            addr = ssram_0_to_3_start_addr + i * w2_split_k_buffer_size
            config.w2_splitn_buffer_list.append(addr)
        for i in range(w2_split_k_buffer_in_ssram_4_to_7):
            addr = ssram_4_to_7_start_addr + i * w2_split_k_buffer_size
            config.w2_splitn_buffer_list.append(addr)
    else:
        config.w2_splitn_buffer_list.append(0)

    # Get the dequantization block sizes.
    block_size = get_dequant_block_size_int4(k, n, k_split, n_split)
    if block_size[0] == -1 or block_size[1] == -1:
        return False
    config.w13_dequant_block_size = block_size[0]
    config.w2_dequant_block_size = block_size[1]
    return True


def get_mlp_split_schedule(k, n):
    min_split = math.ceil(k * n / 2 / 1310720)  # max weight slice is 1.25 MB
    current_split = min_split
    config = MlpSplitConfig()
    for _ in range(100):  # try at most 100 split counts
        split_configs = get_split_config(current_split)
        for cfg in split_configs:
            if validate_mlp_split_config(k, n, cfg[0], cfg[1], config):
                return True, config
        current_split += 1
    return False, None
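# A runnable sketch of the schedule search (sizes are illustrative, roughly a
# Qwen2-7B MLP shard: hidden_size 3584 with intermediate_size 18944 over
# tp_size 4). The search can legitimately fail for some shapes, so the result
# is printed rather than asserted.
def _demo_get_mlp_split_schedule():
    ok, cfg = get_mlp_split_schedule(3584, 4736)
    if ok:
        print("k_split:", cfg.k_split, "n_split:", cfg.n_split,
              "w13 block:", cfg.w13_dequant_block_size,
              "w2 block:", cfg.w2_dequant_block_size)
    else:
        print("no valid MLP split schedule found")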
def Qwen2MLP__init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super(Qwen2MLPOrig, self).__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    self.down_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.down_proj",
    )
    if hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {hidden_act}. "
                         "Only silu is supported for now.")
    self.act_fn = SiluAndMul()

    if hasattr(self.gate_up_proj, 'quant_method') and isinstance(
            self.gate_up_proj.quant_method, GPTQLinearMethod):
        intermediate_size_per_partition = intermediate_size // self.down_proj.tp_size
        success, config = get_mlp_split_schedule(hidden_size,
                                                 intermediate_size_per_partition)
        if success:
            self.w13_dequant_block_size = config.w13_dequant_block_size
            self.w2_dequant_block_size = config.w2_dequant_block_size
            group_size = self.gate_up_proj.quant_method.quant_config.group_size
            scale_k_w13 = group_size // config.w13_dequant_block_size
            scale_k_w2 = group_size // config.w2_dequant_block_size
            self.gate_up_proj.quant_method.scale_k = max(
                scale_k_w13, self.gate_up_proj.quant_method.scale_k)
            self.down_proj.quant_method.scale_k = max(
                scale_k_w2, self.down_proj.quant_method.scale_k)
            if self.gate_up_proj.quant_method.scale_k > 1 and \
                    len(self.gate_up_proj.scales.data.shape) == 2 and \
                    self.gate_up_proj.scales.data.dtype in [torch.float16, torch.bfloat16, torch.float32]:
                w13_scales = self.gate_up_proj.scales.data
                scale_and_zero_size = hidden_size * self.gate_up_proj.quant_method.scale_k // group_size
                w13_scales = torch.empty(
                    scale_and_zero_size,
                    w13_scales.shape[-1],
                    dtype=w13_scales.dtype,
                )
                self.gate_up_proj.scales.data = w13_scales
            if self.down_proj.quant_method.scale_k > 1 and \
                    len(self.down_proj.scales.data.shape) == 2 and \
                    self.down_proj.scales.data.dtype in [torch.float16, torch.bfloat16, torch.float32]:
                w2_scales = self.down_proj.scales.data
                scale_and_zero_size = intermediate_size_per_partition * self.down_proj.quant_method.scale_k // group_size
                w2_scales = torch.empty(
                    scale_and_zero_size,
                    w2_scales.shape[-1],
                    dtype=w2_scales.dtype,
                )
                self.down_proj.scales.data = w2_scales
        else:
            raise ValueError(
                f"Cannot find valid MLP split schedule for hidden_size {hidden_size}, "
                f"intermediate_size {intermediate_size}, tp_size {self.down_proj.tp_size}")


def get_cos_sin_cache(rotary_emb: Union["MRotaryEmbedding", "RotaryEmbedding"],
                      attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]],
                      positions: Union[torch.Tensor, list],
                      is_decode: bool):
    if isinstance(rotary_emb, MRotaryEmbedding):
        # Get the MRoPE sin/cos caches.
        cos_cache, sin_cache = get_sin_cos_mrope(rotary_emb, positions)
        if len(attn_metadata.seq_lens) > 1:
            if is_decode:
                cos_cache = torch.chunk(cos_cache, len(attn_metadata.seq_lens))
                sin_cache = torch.chunk(sin_cache, len(attn_metadata.seq_lens))
            else:
                cos_cache = torch.split(cos_cache, attn_metadata.seq_lens)
                sin_cache = torch.split(sin_cache, attn_metadata.seq_lens)
        else:
            cos_cache = [cos_cache]
            sin_cache = [sin_cache]
    else:
        if is_decode:
            positions = [i - 1 for i in attn_metadata.seq_lens]
            cos_cache = [rotary_emb.cos_cache[i:i + 1, ...] for i in positions]
            sin_cache = [rotary_emb.sin_cache[i:i + 1, ...] for i in positions]
        else:
            cos_cache = [rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
            sin_cache = [rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]
    return cos_cache, sin_cache
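# A runnable sketch of the per-sequence cache slicing in get_cos_sin_cache
# (illustrative; SimpleNamespace stands in for a non-MRoPE rotary embedding
# with precomputed cos_cache/sin_cache tables and for the attention metadata).
def _demo_get_cos_sin_cache():
    from types import SimpleNamespace
    rot = SimpleNamespace(cos_cache=torch.randn(32, 64),
                          sin_cache=torch.randn(32, 64))
    meta = SimpleNamespace(seq_lens=[5, 9])
    # Prefill: one [seq_len, rot_dim] slice per sequence.
    cos, sin = get_cos_sin_cache(rot, meta, positions=None, is_decode=False)
    assert [c.shape[0] for c in cos] == [5, 9]
    # Decode: one [1, rot_dim] slice at the last position of each sequence.
    cos, sin = get_cos_sin_cache(rot, meta, positions=None, is_decode=True)
    assert all(c.shape[0] == 1 for c in cos)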
class Qwen2MLP(nn.Module):

    def forward(self, x):
        # TODO: support the other quant methods.
        if hasattr(self.gate_up_proj, 'quant_method') and isinstance(
                self.gate_up_proj.quant_method, GPTQLinearMethod):
            tp_group = get_tp_group()
            batch_size = x.shape[0]
            if batch_size > 4:
                y = torch.vacc.fuse_mlp_qwen_int4(
                    x,
                    self.gate_up_proj.qweight,
                    self.down_proj.qweight,
                    self.gate_up_proj.scales,
                    self.down_proj.scales,
                    self.gate_up_proj.qzeros,
                    self.down_proj.qzeros,
                    [1, self.w13_dequant_block_size],
                    [1, self.w2_dequant_block_size])
                mlp_out = tensor_model_parallel_all_reduce(y)
                return mlp_out
            else:
                mlp_out = torch.vacc.fuse_mlp_qwen_int4_reduce(
                    x,
                    self.gate_up_proj.qweight,
                    self.down_proj.qweight,
                    self.gate_up_proj.scales,
                    self.down_proj.scales,
                    self.gate_up_proj.qzeros,
                    self.down_proj.qzeros,
                    [1, self.w13_dequant_block_size],
                    [1, self.w2_dequant_block_size],
                    world_size=tp_group.world_size,
                    rank=tp_group.rank_in_group,
                    group_id=tp_group.group_id,
                    dev_info=tp_group.rank_device_infos)
                return mlp_out
        elif hasattr(self.gate_up_proj, 'quant_method') and isinstance(
                self.gate_up_proj.quant_method, UnquantizedLinearMethod):
            batch_size = x.shape[0]
            if batch_size <= 4:
                hidden_states = torch.vacc.fuse_mlp_qwen_fp16_bf16_reduce(
                    x.view(-1, x.shape[-1]),
                    self.gate_up_proj.weight,
                    self.down_proj.weight,
                    world_size=get_tp_group().world_size,
                    rank=get_tp_group().rank_in_group,
                    group_id=get_tp_group().group_id,
                    dev_info=get_tp_group().rank_device_infos)
                return hidden_states
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Qwen2Attention(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,  # Added for the fused kernel
        cos_cache: list[torch.Tensor] = None,
        sin_cache: list[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.attn.layer_name]
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
        is_decode = attn_metadata.prefill_metadata is None

        use_gptq = hasattr(self.qkv_proj, 'quant_method') and isinstance(
            self.qkv_proj.quant_method, GPTQLinearMethod)
        use_unquan = hasattr(self.qkv_proj, 'quant_method') and isinstance(
            self.qkv_proj.quant_method, UnquantizedLinearMethod)

        if use_gptq:
            kv_cache = kv_cache.view(
                2, -1, 16,
                self.total_num_kv_heads // get_tp_group().world_size,
                self.head_dim)
            # Reduce inside the fused kernel only during decode.
            reduce_result = is_decode
            # Get the per-sequence rotary caches.
            if is_decode:
                positions = [i - 1 for i in attn_metadata.seq_lens]
                cos_cache = [self.rotary_emb.cos_cache[i:i + 1, ...] for i in positions]
                sin_cache = [self.rotary_emb.sin_cache[i:i + 1, ...] for i in positions]
            else:
                # In prefill, positions within one sequence are contiguous.
                cos_cache = [self.rotary_emb.cos_cache[:i, ...] for i in attn_metadata.seq_lens]
                sin_cache = [self.rotary_emb.sin_cache[:i, ...] for i in attn_metadata.seq_lens]

            if residual is None:
                res_out = hidden_states

            if USE_FUSED_QWEN_ATTENTION:
                qkv_proj_qzeros = torch.Tensor()
                o_proj_bias = torch.Tensor()
                o_proj_qzeros = torch.Tensor()
                qkv_proj_bias = self.fused_params['qkv_proj_bias']
                # Call the fused attention kernel. For debugging, the pure
                # PyTorch reference vacc_fused_attn_qwen2_naive above computes
                # the same result, and dump_qwen2_attention_params can capture
                # the kernel inputs for offline comparison.
                attn_outs = torch.vacc.fuse_atten_qwen2(
                    hidden_states,
                    residual,
                    self.fused_params['input_layernorm_weight'],
                    self.fused_params['qkv_proj_weight'],
                    self.fused_params['qkv_proj_weight_scale'],
                    qkv_proj_bias,
                    qkv_proj_qzeros,
                    sin_cache,
                    cos_cache,
                    attn_metadata.slot_mapping,
                    kv_cache,
                    attn_metadata.block_tables,
                    env_blk_grp_size,
                    self.fused_params['o_proj_weight'],
                    self.fused_params['o_proj_weight_scale'],
                    o_proj_bias,
                    o_proj_qzeros,
                    attn_metadata.seq_lens,
                    self.scaling,
                    self.total_num_heads,
                    self.total_num_kv_heads,
                    is_decode,
                    is_decode,
                    reduce_result,
                    world_size=get_tp_group().world_size,
                    rank=get_tp_group().rank_in_group,
                    group_id=get_tp_group().group_id,
                    dev_info=get_tp_group().rank_device_infos)
            else:
                attn_outs = vacc_fused_attn_qwen2_naive(
                    hidden_states=hidden_states,
                    residual=residual,
                    hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                    # The GPTQ group sizes are derived from the weight/scale
                    # shapes inside the kernel, so they are not passed here.
                    qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                    qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                    qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                    qkv_proj_qzeros=None,
                    sin_cache=sin_cache,
                    cos_cache=cos_cache,
                    slot_mapping=attn_metadata.slot_mapping,
                    kv_cache=kv_cache,
                    block_tables=attn_metadata.block_tables,
                    block_group_size=env_blk_grp_size,
                    o_proj_weight=self.fused_params['o_proj_weight'],
                    o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                    o_proj_bias=None,
                    o_proj_qzeros=None,
                    seq_lens=attn_metadata.seq_lens,
                    sm_scale=self.scaling,
                    num_attention_heads=self.total_num_heads,
                    num_key_value_heads=self.total_num_kv_heads,
                    is_decode=is_decode,
                    flash_attention=is_decode,
                    reduce_result=reduce_result,
                    world_size=get_tp_group().world_size,
                    rank=get_tp_group().rank_in_group,
                    group_id=get_tp_group().group_id,
                    dev_info=get_tp_group().rank_device_infos)

            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) \
                    if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) \
                    if not reduce_result else attn_outs[0]
            return attn_out, res_out

        elif use_unquan:
            kv_cache = kv_cache.view(
                2, -1, 16,
                self.total_num_kv_heads // get_tp_group().world_size,
                self.head_dim)
            # Reduce inside the fused kernel only during decode.
            reduce_result = is_decode

            if cos_cache is None or sin_cache is None:
                cos_cache, sin_cache = get_cos_sin_cache(
                    self.rotary_emb, attn_metadata, positions, is_decode)

            if residual is None:
                res_out = hidden_states

            if USE_FUSED_QWEN_ATTENTION:
                qkv_proj_qzeros = torch.Tensor()
                o_proj_bias = torch.Tensor()
                o_proj_qzeros = torch.Tensor()
                # Call the fused attention kernel.
                attn_outs = torch.vacc.fuse_atten_qwen2(
                    hidden_states,
                    residual,
                    self.fused_params['input_layernorm_weight'],
                    self.fused_params['qkv_proj_weight'],
                    self.fused_params['qkv_proj_weight_scale'],
                    self.fused_params['qkv_proj_bias'],
                    qkv_proj_qzeros,
                    sin_cache,
                    cos_cache,
                    attn_metadata.slot_mapping,
                    kv_cache,
                    attn_metadata.block_tables,
                    env_blk_grp_size,
                    self.fused_params['o_proj_weight'],
                    self.fused_params['o_proj_weight_scale'],
                    o_proj_bias,
                    o_proj_qzeros,
                    attn_metadata.seq_lens,
                    self.scaling,
                    self.total_num_heads,
                    self.total_num_kv_heads,
                    is_decode,
                    is_decode,
                    reduce_result,
                    world_size=get_tp_group().world_size,
                    rank=get_tp_group().rank_in_group,
                    group_id=get_tp_group().group_id,
                    dev_info=get_tp_group().rank_device_infos)
            else:
                attn_outs = vacc_fused_attn_qwen2_naive(
                    hidden_states=hidden_states,
                    residual=residual,
                    hidden_states_norm_weight=self.fused_params['input_layernorm_weight'],
                    qkv_proj_weight=self.fused_params['qkv_proj_weight'],
                    qkv_proj_weight_scale=self.fused_params['qkv_proj_weight_scale'],
                    qkv_proj_bias=self.fused_params['qkv_proj_bias'],
                    qkv_proj_qzeros=None,
                    sin_cache=sin_cache,
                    cos_cache=cos_cache,
                    slot_mapping=attn_metadata.slot_mapping,
                    kv_cache=kv_cache,
                    block_tables=attn_metadata.block_tables,
                    block_group_size=env_blk_grp_size,
                    o_proj_weight=self.fused_params['o_proj_weight'],
                    o_proj_weight_scale=self.fused_params['o_proj_weight_scale'],
                    o_proj_bias=None,
                    o_proj_qzeros=None,
                    seq_lens=attn_metadata.seq_lens,
                    sm_scale=self.scaling,
                    num_attention_heads=self.total_num_heads,
                    num_key_value_heads=self.total_num_kv_heads,
                    is_decode=is_decode,
                    flash_attention=is_decode,
                    reduce_result=reduce_result,
                    world_size=get_tp_group().world_size,
                    rank=get_tp_group().rank_in_group,
                    group_id=get_tp_group().group_id,
                    dev_info=get_tp_group().rank_device_infos)

            if residual is None:
                attn_out = tensor_model_parallel_all_reduce(attn_outs) \
                    if not reduce_result else attn_outs
            else:
                res_out = attn_outs[1]
                attn_out = tensor_model_parallel_all_reduce(attn_outs[0]) \
                    if not reduce_result else attn_outs[0]
            return attn_out, res_out

        else:
            # Fall back to the original implementation.
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
            q, k = self.rotary_emb(positions, q, k)
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output
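# A runnable sketch of the decode-time block-table regrouping used by both
# attention paths (values are illustrative): 16-token vLLM cache blocks are
# mapped onto hardware groups of block_group_size tokens by integer-dividing
# the block ids by the number of blocks per group.
def _demo_block_table_regrouping():
    block_size, block_group_size = 16, 64
    block_per_group = block_group_size // block_size  # 4 blocks per group
    block_tables = torch.tensor([[0, 4, 8, 12]], dtype=torch.int32)
    grouped = (block_tables // block_per_group).to(torch.int32)
    assert grouped.tolist() == [[0, 1, 2, 3]]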
class Qwen2DecoderLayer(nn.Module):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        use_gptq = hasattr(self.self_attn.qkv_proj, 'quant_method') and isinstance(
            self.self_attn.qkv_proj.quant_method, GPTQLinearMethod)
        use_unquan = hasattr(self.self_attn.qkv_proj, 'quant_method') and isinstance(
            self.self_attn.qkv_proj.quant_method, UnquantizedLinearMethod)
        if use_gptq or use_unquan:
            # Self-attention with fused pre-layernorm: pass the layernorm
            # weight to the attention layer for the fused kernel.
            self.self_attn.input_layernorm_weight = self.input_layernorm.weight
            self.self_attn.fused_params = {}
            self.self_attn.fused_params['input_layernorm_weight'] = self.input_layernorm.weight
            set_fused_params(self.self_attn.fused_params,
                             self.self_attn.qkv_proj.quant_method,
                             self.self_attn.qkv_proj, 'qkv_proj')
            set_fused_params(self.self_attn.fused_params,
                             self.self_attn.o_proj.quant_method,
                             self.self_attn.o_proj, 'o_proj')
            hidden_states, residual = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
        else:
            # Original non-fused path.
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
            )

        # The fully connected (MLP) part remains the same.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual