sglang/python/sglang/srt/layers/radix_attention.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Radix attention."""
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Optional

from torch import nn

if TYPE_CHECKING:
    from sglang.srt.layers.quantization.base_config import QuantizationConfig
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class AttentionType(Enum):
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """

    # Decoder attention between previous layer Q/K/V
    DECODER = "decoder"
    # Encoder attention between previous layer Q/K/V
    ENCODER_ONLY = "encoder_only"

class RadixAttention(nn.Module):
    """
    The attention layer implementation.
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scaling: float,
        num_kv_heads: int,
        layer_id: int,
        logit_cap: float = 0.0,
        v_head_dim: int = -1,
        sliding_window_size: int = -1,
        is_cross_attention: bool = False,
        pos_encoding_mode: str = "NONE",
        logit_capping_method: str = "tanh",
        quant_config: Optional[QuantizationConfig] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.tp_v_head_num = num_kv_heads
        self.head_dim = head_dim
        self.qk_head_dim = head_dim
        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
        self.scaling = scaling
        self.layer_id = layer_id
        self.logit_cap = logit_cap
        self.sliding_window_size = sliding_window_size or -1
        self.is_cross_attention = is_cross_attention
        self.use_irope = use_irope
        self.k_scale = None
        self.v_scale = None
        self.k_scale_float = None
        self.v_scale_float = None
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        if self.quant_method is not None:
            self.quant_method.create_weights(self)
        self.attn_type = attn_type
        self.pos_encoding_mode = pos_encoding_mode
        self.logit_capping_method = logit_capping_method
        self.xai_temperature_len = -1
    def forward(
        self,
        q,
        k,
        v,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        if k is not None:
            # For cross-layer sharing, kv can be None
            assert v is not None
            if "k_rope" not in kwargs:
                k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
                v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
            else:
                k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
        return forward_batch.attn_backend.forward(
            q,
            k,
            v,
            self,
            forward_batch,
            save_kv_cache,
            **kwargs,
        )
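
A minimal usage sketch (not part of radix_attention.py): a model's self-attention block typically projects hidden states to Q/K/V and hands the flat tensors to RadixAttention, which dispatches to the attention backend carried by the runtime ForwardBatch. The module name SimpleSelfAttention, the fused QKV projection, and the tensor shapes below are illustrative assumptions, not the actual SGLang model code.

import torch
from torch import nn

from sglang.srt.layers.radix_attention import RadixAttention


class SimpleSelfAttention(nn.Module):
    """Illustrative wrapper around RadixAttention (hypothetical example)."""

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int, layer_id: int):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden_size // num_heads
        # Fused QKV projection; the output is split into Q, K, V in forward().
        self.qkv_proj = nn.Linear(
            hidden_size,
            (num_heads + 2 * num_kv_heads) * self.head_dim,
            bias=False,
        )
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False)
        # RadixAttention holds per-layer metadata; the actual attention kernel
        # and KV-cache writes are performed by forward_batch.attn_backend.
        self.attn = RadixAttention(
            num_heads=num_heads,
            head_dim=self.head_dim,
            scaling=self.head_dim**-0.5,
            num_kv_heads=num_kv_heads,
            layer_id=layer_id,
        )

    def forward(self, hidden_states: torch.Tensor, forward_batch) -> torch.Tensor:
        # forward_batch is the runtime ForwardBatch provided by the SGLang executor.
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split(
            [
                self.num_heads * self.head_dim,
                self.num_kv_heads * self.head_dim,
                self.num_kv_heads * self.head_dim,
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v, forward_batch)
        return self.o_proj(attn_output)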