# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy import functools import logging import math from dataclasses import replace from functools import partial import torch import torch.nn.functional as F from torch import nn from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.models.mistral import MistralMLP from vllm.model_executor.models.whisper import WhisperPosEmbedType from vllm.v1.attention.backend import ( AttentionBackend, AttentionMetadata, AttentionType, CommonAttentionMetadata, subclass_attention_backend_with_overrides, ) from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend try: from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend except ImportError: AiterFlashAttentionBackend = None from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend from vllm.v1.attention.selector import get_attn_backend from vllm.v1.kv_cache_interface import AttentionSpec from .utils import make_layers logger = logging.getLogger(__name__) CausalRMSNorm = partial(RMSNorm, eps=1e-5) def _pad1d( x: torch.Tensor, paddings: tuple[int, int], mode: str = "constant", value: float = 0.0, ) -> torch.Tensor: """Tiny wrapper around F.pad, just to allow for reflect padding on small input. If this is the case, we insert extra 0 padding to the right before the reflection happen. """ length = x.shape[-1] padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) if mode == "reflect": max_pad = max(padding_left, padding_right) extra_pad = 0 if length <= max_pad: extra_pad = max_pad - length + 1 x = F.pad(x, (0, extra_pad)) padded = F.pad(x, paddings, mode, value) end = padded.shape[-1] - extra_pad return padded[..., :end] else: return F.pad(x, paddings, mode, value) class WhisperCausalConv1d(nn.Conv1d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = True, ) -> None: super().__init__( in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias, ) self._stride = self.stride[0] self._effective_kernel_size = (kernel_size - 1) * self.dilation[0] + 1 self._padding_total = self._effective_kernel_size - self._stride def forward(self, x: torch.Tensor) -> torch.Tensor: n_frames = ( x.shape[-1] - self._effective_kernel_size + self._padding_total ) / self._stride + 1 target_length = (math.ceil(n_frames) - 1) * self._stride + ( self._effective_kernel_size - self._padding_total ) extra_padding = target_length - x.shape[-1] x = _pad1d(x, (self._padding_total, extra_padding), mode="constant") return super().forward(x) @functools.lru_cache def create_whisper_attention_backend_with_block_pooling( underlying_attn_backend: AttentionBackend, block_pool_size: int ) -> type[AttentionBackend]: prefix = "WhisperCausalAttentionWithBlockPooling_" underlying_builder = underlying_attn_backend.get_builder_cls() underlying_impl = underlying_attn_backend.get_impl_cls() class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore def __init__( self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device, ): assert kv_cache_spec.num_kv_heads % block_pool_size == 0 kv_cache_spec = replace( kv_cache_spec, block_size=kv_cache_spec.block_size * block_pool_size, num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size, ) super().__init__(kv_cache_spec, layer_names, vllm_config, device) # Override model_config-derived values with the actual # encoder values from kv_cache_spec self.num_heads_kv = kv_cache_spec.num_kv_heads self.headdim = kv_cache_spec.head_size # num_heads_q for the encoder is the same as num_kv_heads # (no GQA in whisper encoder) self.num_heads_q = kv_cache_spec.num_kv_heads def build( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, ) -> AttentionMetadata: new_common_attn_metadata = copy.deepcopy(common_attn_metadata) new_common_attn_metadata.query_start_loc *= block_pool_size new_common_attn_metadata.query_start_loc_cpu *= block_pool_size new_common_attn_metadata.seq_lens *= block_pool_size new_common_attn_metadata._seq_lens_cpu *= block_pool_size new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size new_common_attn_metadata.num_actual_tokens *= block_pool_size new_common_attn_metadata.max_query_len *= block_pool_size new_common_attn_metadata.max_seq_len *= block_pool_size original_slot_mapping = common_attn_metadata.slot_mapping common_prefix_len *= block_pool_size new_common_attn_metadata.slot_mapping = ( ( original_slot_mapping.unsqueeze(1) * block_pool_size + torch.arange(block_pool_size, device=original_slot_mapping.device) ) .flatten() .clamp(min=-1) ) return super().build( common_prefix_len, new_common_attn_metadata, fast_build ) # NOTE: We need a custom impl so we can use the transformed slot_mapping # computed by `WhisperCausalAttentionWithBlockPoolingBuilder` instead of # the one from `forward_context.slot_mapping` (gpu_model_runner). # This follows the same pattern as CrossAttentionImpl. class WhisperCausalAttentionWithBlockPoolingImpl(underlying_impl): # type: ignore[valid-type,misc] def forward( self, layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: if ( not underlying_attn_backend.forward_includes_kv_cache_update and attn_metadata is not None and layer.kv_sharing_target_layer_name is None and key is not None and value is not None ): self.do_kv_cache_update( layer, key, value, kv_cache, attn_metadata.slot_mapping ) return super().forward( layer, query, key, value, kv_cache, attn_metadata, output, output_scale, output_block_scale, ) _SUPPORTED_BACKENDS = tuple( b for b in ( AiterFlashAttentionBackend, FlashAttentionBackend, RocmAttentionBackend, TritonAttentionBackend, ) if b is not None ) if not issubclass(underlying_attn_backend, _SUPPORTED_BACKENDS): raise NotImplementedError( f"{underlying_attn_backend} is not yet supported." "Contributions to support more backends are much " "appreciated." ) if not issubclass(underlying_attn_backend, FlashAttentionBackend): logger.info( "Using %s for Whisper causal attention with block pooling. " "This backend was recently enabled for this model. " "If you encounter any accuracy or performance issues, " "please open an issue at " "https://github.com/vllm-project/vllm/issues " "with the [ROCm] tag so it can be triaged by the " "appropriate team.", underlying_attn_backend.get_name(), ) attn_backend = subclass_attention_backend_with_overrides( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, overrides={ "get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder, "get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl, "get_kv_cache_shape": lambda num_blocks, block_size, num_kv_heads, head_size, cache_dtype_str: underlying_attn_backend.get_kv_cache_shape( num_blocks, # we stretch each block by `block_pool_size` block_size * block_pool_size, num_kv_heads // block_pool_size, head_size, cache_dtype_str, ), "forward_includes_kv_cache_update": True, }, ) return attn_backend class WhisperCausalAttentionWithBlockPooling(Attention): """Attention layer with block pooling.""" def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int | None = None, alibi_slopes: list[float] | None = None, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, logits_soft_cap: float | None = None, per_layer_sliding_window: int | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, block_pool_size: int = 1, attn_backend: type[AttentionBackend] | None = None, **extra_impl_args, ) -> None: self.block_pool_size = block_pool_size dtype = torch.get_default_dtype() if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size else: kv_cache_dtype = "auto" block_size = 16 underlying_attn_backend = get_attn_backend( head_size, dtype, kv_cache_dtype, block_size, attn_type=attn_type, ) attn_backend = create_whisper_attention_backend_with_block_pooling( underlying_attn_backend, block_pool_size ) super().__init__( num_heads=num_heads, head_size=head_size, scale=scale, num_kv_heads=num_kv_heads, alibi_slopes=alibi_slopes, cache_config=cache_config, quant_config=quant_config, logits_soft_cap=logits_soft_cap, per_layer_sliding_window=per_layer_sliding_window, prefix=prefix, attn_type=attn_type, kv_sharing_target_layer_name=kv_sharing_target_layer_name, attn_backend=attn_backend, **extra_impl_args, ) def get_kv_cache_spec(self, vllm_config: VllmConfig): kv_cache_spec = super().get_kv_cache_spec(vllm_config) assert isinstance(kv_cache_spec, AttentionSpec) kv_cache_spec = replace( kv_cache_spec, num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads, ) return kv_cache_spec class WhisperCausalAttention(nn.Module): def __init__( self, embed_dim: int, num_heads: int, head_dim: int, max_position_embeddings: int, bias: bool = True, attn_type: AttentionType = AttentionType.DECODER, per_layer_sliding_window: int | None = None, block_pool_size: int = 1, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() self.embed_dim = embed_dim tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size if self.total_num_heads >= tp_size: # Number of heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. assert self.total_num_heads % tp_size == 0 else: # Number of heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_heads == 0 self.num_kv_heads = max(1, self.total_num_heads // tp_size) self.head_dim = head_dim self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.attn_type = attn_type self.scaling = self.head_dim**-0.5 self._init_qkv(embed_dim, bias, quant_config, prefix=prefix) self.out_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, output_size=embed_dim, bias=bias, quant_config=quant_config, prefix=f"{prefix}.out_proj", ) assert block_pool_size > 1, ( f"Causal attention only supports block_pool_size>1, not {block_pool_size}." ) self.attn = WhisperCausalAttentionWithBlockPooling( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", attn_type=AttentionType.DECODER, per_layer_sliding_window=per_layer_sliding_window, block_pool_size=block_pool_size, ) assert per_layer_sliding_window is not None, ( "rope can only used in combination with a sliding window" ) self._init_rotary_emb(max_position_embeddings) def _init_rotary_emb(self, max_position_embeddings: int) -> None: self.rotary_emb = get_rope( self.head_dim, max_position=max_position_embeddings, is_neox_style=False, rope_parameters={"rope_theta": 1e6}, ) def _init_qkv( self, embed_dim: int, bias: bool = True, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: self.qkv_proj = QKVParallelLinear( hidden_size=embed_dim, head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_heads, bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) def forward( self, hidden_states: torch.Tensor, positions: torch.Tensor | None = None, ): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) assert positions is not None q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.out_proj(attn_output) return output class WhisperCausalEncoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config sliding_window = getattr(config, "sliding_window", None) block_pool_size = config.block_pool_size assert block_pool_size > 1 cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.embed_dim = config.d_model self.head_dim = self.embed_dim // config.encoder_attention_heads self.self_attn = WhisperCausalAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, head_dim=config.encoder_head_dim, max_position_embeddings=config.max_position_embeddings, block_pool_size=block_pool_size, per_layer_sliding_window=sliding_window, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", ) self.self_attn_layer_norm = CausalRMSNorm(self.embed_dim) self.mlp = MistralMLP( hidden_size=config.d_model, intermediate_size=config.encoder_ffn_dim, hidden_act="silu", quant_config=quant_config, bias=True, gate_up_proj_bias=False, prefix=f"{prefix}.mlp", ) self.final_layer_norm = CausalRMSNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, positions: torch.Tensor | None = None, ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn(hidden_states=hidden_states, positions=positions) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class WhisperCausalEncoder(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config embed_dim = config.d_model assert WhisperPosEmbedType(config.pos_embed) == WhisperPosEmbedType.ROPE assert config.is_causal self.num_mel_bins = config.num_mel_bins self.max_source_positions = config.max_source_positions self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 self.conv1 = WhisperCausalConv1d(self.num_mel_bins, embed_dim, kernel_size=3) self.conv2 = WhisperCausalConv1d(embed_dim, embed_dim, stride=2, kernel_size=3) self.total_stride = self.conv1.stride[0] * self.conv2.stride[0] self.start_layer, self.end_layer, self.layers = make_layers( config.encoder_layers, lambda prefix: WhisperCausalEncoderLayer( vllm_config=vllm_config, prefix=f"{prefix}.layers" ), prefix=f"{prefix}.layers", ) self.layer_norm = CausalRMSNorm(config.d_model) def forward_conv( self, input_features: torch.Tensor | list[torch.Tensor] ) -> torch.Tensor: hidden_states = [] for features in input_features: embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv2(embeds)) embeds = embeds.transpose(-1, -2).to(embeds.dtype) hidden_states.append(embeds) hidden_states = torch.cat(hidden_states) return hidden_states def forward( self, hidden_states: torch.Tensor, positions: torch.Tensor ) -> torch.Tensor: for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states, positions) hidden_states = self.layer_norm(hidden_states) return hidden_states