Feat/support encoder model (like bert) (#4887)
@@ -6,6 +6,7 @@ import torch

 from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
         o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self._run_sdpa_forward_extend(
             q_,
             o_,
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
             forward_batch.extend_seq_lens,
             scaling=layer.scaling,
             enable_gqa=use_gqa,
-            causal=not layer.is_cross_attention,
+            causal=causal,
         )
         return o
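
For context: the `causal` flag computed above is what `_run_sdpa_forward_extend` ultimately passes down to PyTorch's `scaled_dot_product_attention` (imported at the top of this file). A minimal standalone sketch of the difference the flag makes; the tensor shapes below are illustrative assumptions, not values from this diff:

# Illustrative sketch, not part of the PR: causal vs. bidirectional SDPA.
import torch
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(1, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# DECODER layers: token i may only attend to tokens j <= i.
out_decoder = scaled_dot_product_attention(q, k, v, is_causal=True)

# ENCODER_ONLY (BERT-style) layers and cross-attention: every token attends
# to the full sequence, which is what causal = False selects above.
out_encoder = scaled_dot_product_attention(q, k, v, is_causal=False)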

@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var, get_device_core_count

@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
             layer, forward_batch.out_cache_loc, k, v
         )

+        causal = True
+        if layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.kv_indptr,
             self.forward_metadata.kv_indices,
             self.forward_metadata.custom_mask,
+            causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
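
This is the same selection rule the torch-native backend applies, except the Triton extend path keys only on `attn_type` (cross-attention is handled elsewhere). A hypothetical consolidation, ours and not part of the PR, that states the rule both backends now implement inline:

# Hypothetical helper (not in the diff): the shared causal-selection rule.
from sglang.srt.layers.radix_attention import AttentionType

def select_causal(attn_type, is_cross_attention: bool = False) -> bool:
    # Encoder-only (BERT-style) layers and cross-attention read the whole
    # key sequence; only decoder self-attention stays causal.
    if is_cross_attention or attn_type == AttentionType.ENCODER_ONLY:
        return False
    return True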

@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
 ):
@@ -129,6 +130,7 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
+
         offs_kv_loc = tl.load(
             kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
@@ -196,7 +198,11 @@ def _fwd_kernel(

     # stage 2: compute the triangle part

-    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    cur_block_m_end = (
+        cur_seq_len_extend
+        if not IS_CAUSAL
+        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    )
     for start_n in range(0, cur_block_m_end, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_block_m_end
@@ -243,12 +249,15 @@ def _fwd_kernel(
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(custom_mask, qk, float("-inf"))
-        else:
+        elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
             )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(mask_causual, qk, float("-inf"))
+        else:
+            mask_non_causal = mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_non_causal, qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
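
In plain PyTorch terms, the stage-2 change does two things: it widens the inner loop bound from the current block row to the full extend length when `IS_CAUSAL` is off, and it swaps the triangular mask for a padding-only mask. A dense, single-head reference of the same semantics (illustrative shapes, no BLOCK_M x BLOCK_N tiling):

# Reference semantics of stage 2, not the kernel itself.
import torch

def stage2_scores(q_ext: torch.Tensor, k_ext: torch.Tensor, is_causal: bool):
    # q_ext, k_ext: (seq_len_extend, head_dim) for one head
    qk = q_ext @ k_ext.T  # (seq_len_extend, seq_len_extend)
    if is_causal:
        # Triangle part: extend token i attends to extend tokens j <= i,
        # matching mask_causual in the kernel.
        keep = torch.tril(torch.ones_like(qk, dtype=torch.bool))
    else:
        # IS_CAUSAL=False keeps every (i, j) pair, which is why the kernel's
        # loop bound becomes cur_seq_len_extend instead of the block row end.
        keep = torch.ones_like(qk, dtype=torch.bool)
    return qk.masked_fill(~keep, float("-inf"))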

@@ -299,6 +308,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
+    is_causal,
     mask_indptr,
     max_len_extend,
     sm_scale=None,
@@ -411,6 +421,7 @@ def extend_attention_fwd(
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
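
Passing `is_causal` through as the `IS_CAUSAL: tl.constexpr` kernel argument means Triton compiles a separate specialization per value and dead-code-eliminates the untaken branch, so the bidirectional path adds no runtime cost to decoder runs. A minimal standalone illustration of that mechanism; this toy kernel is our own, not from the PR:

# Toy example of tl.constexpr branch specialization (illustrative only).
import torch
import triton
import triton.language as tl

@triton.jit
def row_mask_kernel(scores_ptr, n_cols, IS_CAUSAL: tl.constexpr, BLOCK: tl.constexpr):
    row = tl.program_id(0)                    # one program per matrix row
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    s = tl.load(scores_ptr + row * n_cols + cols, mask=mask, other=0.0)
    if IS_CAUSAL:
        # Compiled out entirely in the IS_CAUSAL=False specialization.
        s = tl.where(cols <= row, s, float("-inf"))
    tl.store(scores_ptr + row * n_cols + cols, s, mask=mask)

scores = torch.randn(16, 16, device="cuda")
row_mask_kernel[(16,)](scores, 16, IS_CAUSAL=True, BLOCK=16)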

@@ -13,6 +13,7 @@
 # ==============================================================================
 """Radix attention."""

+from enum import Enum
 from typing import Optional

 from torch import nn
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


+class AttentionType(Enum):
+    """
+    Attention type.
+    Use string to be compatible with `torch.compile`.
+    """
+
+    # Decoder attention between previous layer Q/K/V
+    DECODER = "decoder"
+    # Encoder attention between previous layer Q/K/V
+    ENCODER_ONLY = "encoder_only"
+
+
 class RadixAttention(nn.Module):
     """
     The attention layer implementation.
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
+        attn_type=AttentionType.DECODER,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
         self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
         if self.quant_method is not None:
             self.quant_method.create_weights(self)
+        self.attn_type = attn_type

     def forward(
         self,
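
Model code opts in per layer through the new keyword, while every existing call site keeps the `AttentionType.DECODER` default. A small usage sketch; the surrounding constructor arguments are elided, as they are in the diff:

# Usage sketch: encoder-only (BERT-style) layers pass the new keyword.
from sglang.srt.layers.radix_attention import AttentionType

attn_type = AttentionType.ENCODER_ONLY
assert attn_type.value == "encoder_only"  # string-valued for torch.compile
assert attn_type != AttentionType.DECODER

# e.g. RadixAttention(..., attn_type=AttentionType.ENCODER_ONLY) inside a
# BERT self-attention module (other constructor arguments elided).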