################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Any, Optional, Tuple

import torch
import torch_br
from torch import nn
from torch_br.supa.profiler_kineto import record_function
from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import (MRotaryEmbedding,
                                                         get_rope)
from vllm.model_executor.models.utils import extract_layer_index


class AttentionSplit(nn.Module):
    """Tensor-parallel attention layer with *separate* Q/K/V projections.

    Unlike the common fused ``qkv_proj``, this layer keeps ``q_proj``,
    ``k_proj`` and ``v_proj`` as three independent ``ColumnParallelLinear``
    layers.  Keeping them split allows ``forward`` to dispatch to the vendor
    fused decode kernel ``torch_br.supa_qkv_rope_decode_infer`` (which
    consumes the three weight tensors directly) on the decode path, while
    falling back to the ordinary projection + rotary-embedding + attention
    sequence otherwise.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: int = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        rope_scaling: Optional[Tuple] = None,
        attn_type: str = AttentionType.DECODER,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        bias: bool = False,
    ) -> None:
        """Build the split Q/K/V projections, rotary embedding and attention.

        Args:
            hidden_size: Model hidden dimension; ``head_dim`` is derived as
                ``hidden_size // num_heads``.
            num_heads: Total number of attention (query) heads across all
                tensor-parallel ranks.
            num_kv_heads: Total number of key/value heads across all ranks
                (may be smaller than ``num_heads`` for GQA/MQA).
            max_position: Maximum position supported by the rotary embedding.
            rope_theta: RoPE base frequency.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Quantization configuration.  It is applied to the
                Q/K/V projections only when ``quant_config.qkv_quantized`` is
                set; ``o_proj`` always receives it.
            rope_scaling: RoPE scaling parameters forwarded to ``get_rope``.
            attn_type: Attention type (decoder by default).
            prefix: Dotted parameter-name prefix used for weight loading and
                for extracting the layer index.
            dual_chunk_attention_config: Optional dual-chunk-attention
                settings; when present, ``layer_idx`` and this config are
                passed through to ``Attention`` as extra kwargs.
            bias: Whether the Q/K/V projections carry a bias term
                (``o_proj`` never does).
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        # Per-rank number of query heads after tensor-parallel partitioning.
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank flattened projection widths.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Standard scaled-dot-product attention scale: 1/sqrt(head_dim).
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        # Only quantize the Q/K/V projections if the config explicitly
        # marks them as quantized; otherwise they stay in full precision
        # while o_proj still honors quant_config below.
        qconfig = None
        if quant_config is not None and quant_config.qkv_quantized:
            qconfig = quant_config
        # NOTE: output_size is the *global* size (per-rank size * tp_size);
        # ColumnParallelLinear shards it across ranks internally.
        self.q_proj = ColumnParallelLinear(input_size=hidden_size,
                                           output_size=self.q_size * tp_size,
                                           bias=bias,
                                           quant_config=qconfig,
                                           prefix=f"{prefix}.q_proj")
        self.k_proj = ColumnParallelLinear(input_size=hidden_size,
                                           output_size=self.kv_size * tp_size,
                                           bias=bias,
                                           quant_config=qconfig,
                                           prefix=f"{prefix}.k_proj")
        self.v_proj = ColumnParallelLinear(input_size=hidden_size,
                                           output_size=self.kv_size * tp_size,
                                           bias=bias,
                                           quant_config=qconfig,
                                           prefix=f"{prefix}.v_proj")
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
        )
        # Extra kwargs (layer index + dual-chunk config) are only forwarded
        # when dual-chunk attention is actually configured.
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            } if dual_chunk_attention_config else {})

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Two execution paths:

        1. Fused decode path: when the projection weights are 3-D (which
           this code treats as the NUMA/int8 weight layout — TODO confirm),
           the rotary embedding is not MRoPE (i.e. not a qwen-vl style
           multimodal model), and the sequence length is short enough to be
           a decode step, QKV projection + RoPE + KV-cache write are done in
           one vendor kernel (``supa_qkv_rope_decode_infer``).
        2. Fallback path: separate Q/K/V projections, rotary embedding,
           then regular attention.

        Args:
            positions: Position ids for the rotary embedding.
            hidden_states: Input activations; the second-to-last dimension
                is taken as the sequence length.

        Returns:
            The attention output after ``o_proj`` — or ``hidden_states``
            unchanged when there is no attention metadata (dummy run) or no
            KV cache is available on the fused path.
        """
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            ## for dummy run
            return hidden_states
        # assumes hidden_states is (..., seq, hidden) — TODO confirm
        seq_len = hidden_states.shape[-2]
        # Threshold below which this step is treated as decode; sequences
        # longer than this take the fallback (prefill) path.
        decode_seql = 512
        # numa weight and not use mrope (qwen-vl)
        # A 3-D (q)weight tensor is the marker for the weight layout the
        # fused kernel expects — presumably the NUMA-partitioned layout
        # mentioned in the comment above; verify against the loader.
        if ((hasattr(self.q_proj, "qweight")
             and len(self.q_proj.qweight.shape) == 3)
                or (hasattr(self.q_proj, "weight")
                    and len(self.q_proj.weight.shape) == 3)) and not isinstance(
                        self.rotary_emb,
                        MRotaryEmbedding) and seq_len <= decode_seql:
            # Per-layer metadata may be stored as a dict keyed by layer name.
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.attn.layer_name]
            kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
            if kv_cache is not None:
                with record_function('attention qkv_rope'):
                    # int8 weight version
                    # Prefer the quantized tensors ("qweight"/"scales") when
                    # present; fall back to the full-precision "weight".
                    q_weight = self.q_proj.qweight if hasattr(
                        self.q_proj, "qweight") else self.q_proj.weight
                    k_weight = self.k_proj.qweight if hasattr(
                        self.k_proj, "qweight") else self.k_proj.weight
                    v_weight = self.v_proj.qweight if hasattr(
                        self.v_proj, "qweight") else self.v_proj.weight
                    q_scale = self.q_proj.scales if hasattr(
                        self.q_proj, "scales") else None
                    k_scale = self.k_proj.scales if hasattr(
                        self.k_proj, "scales") else None
                    v_scale = self.v_proj.scales if hasattr(
                        self.v_proj, "scales") else None
                    q_bias = self.q_proj.bias if hasattr(self.q_proj,
                                                         "bias") else None
                    k_bias = self.k_proj.bias if hasattr(self.k_proj,
                                                         "bias") else None
                    v_bias = self.v_proj.bias if hasattr(self.v_proj,
                                                         "bias") else None
                    # Fused kernel: QKV projection + RoPE + KV-cache update
                    # in one launch; writes K/V into kv_cache at
                    # slot_mapping positions (per its argument list).
                    q, k, v = torch_br.supa_qkv_rope_decode_infer(
                        hidden_states,
                        q_weight,
                        k_weight,
                        v_weight,
                        self.rotary_emb.sin_cache,
                        self.rotary_emb.cos_cache,
                        kv_cache,
                        positions,
                        attn_metadata.slot_mapping,
                        self.rotary_emb.head_size,
                        self.q_size,
                        self.kv_size,
                        q_scale=q_scale,
                        k_scale=k_scale,
                        v_scale=v_scale,
                        q_bias=q_bias,
                        k_bias=k_bias,
                        v_bias=v_bias)
                    # The fused kernel already wrote the KV cache, so tell
                    # the attention backend not to cache again.
                    if hasattr(attn_metadata, 'do_cache'):
                        attn_metadata.do_cache = False
                with record_function('attention'):
                    attn_output = self.attn(q, k, v)
                with record_function('attention o_proj'):
                    output, _ = self.o_proj(attn_output)
                return output
            else:
                # No KV cache for this virtual engine: pass input through
                # unchanged (mirrors the dummy-run behavior above).
                return hidden_states
        else:
            # uma weight or use mrope (qwen-vl)
            # Fallback: standard split projections + rotary embedding;
            # the attention backend handles KV caching itself.
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
            q, k = self.rotary_emb(positions, q, k)
            if hasattr(attn_metadata, 'do_cache'):
                attn_metadata.do_cache = True
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output