Feat/support encoder model (like BERT) (#4887)

Commit 3bface15e6 (parent 6fb29ffd9e)
Author: woodx
Date: 2025-04-17 16:50:48 +08:00
Committed by: GitHub
8 changed files with 593 additions and 3 deletions


@@ -13,6 +13,7 @@
 # ==============================================================================
 """Radix attention."""
+from enum import Enum
 from typing import Optional
 from torch import nn
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


+class AttentionType(Enum):
+    """
+    Attention type.
+    Use string values to stay compatible with `torch.compile`.
+    """
+
+    # Decoder attention between previous layer Q/K/V (causal)
+    DECODER = "decoder"
+    # Encoder-only attention between previous layer Q/K/V (bidirectional, no causal mask)
+    ENCODER_ONLY = "encoder_only"
+
+
 class RadixAttention(nn.Module):
     """
     The attention layer implementation.
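
For context, here is a minimal, self-contained sketch (not part of the commit) of how the string-valued enum behaves: members compare by identity, and `.value` exposes a plain string that a graph tracer such as `torch.compile` can specialize on as a constant.

```python
# Standalone illustration of the AttentionType pattern added above.
from enum import Enum


class AttentionType(Enum):
    # Decoder attention: causal masking over previous positions.
    DECODER = "decoder"
    # Encoder-only attention: bidirectional, no causal mask.
    ENCODER_ONLY = "encoder_only"


def is_causal(attn_type: AttentionType) -> bool:
    # Only decoder attention applies a causal mask.
    return attn_type == AttentionType.DECODER


assert is_causal(AttentionType.DECODER)
assert not is_causal(AttentionType.ENCODER_ONLY)
assert AttentionType.ENCODER_ONLY.value == "encoder_only"
```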
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
+        attn_type: AttentionType = AttentionType.DECODER,
         prefix: str = "",
         use_irope: bool = False,
     ):
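
The default keeps existing decoder models unchanged; an encoder-only model opts in at construction time. A hedged sketch of how a BERT-style layer might do so: the positional arguments (num_heads, head_dim, scaling, num_kv_heads, layer_id) are assumed from the surrounding RadixAttention signature and are not part of this diff, and the values are hypothetical BERT-base numbers.

```python
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention

num_heads = 12  # hypothetical BERT-base configuration
head_dim = 64

attn = RadixAttention(
    num_heads,
    head_dim,
    scaling=head_dim**-0.5,  # standard 1/sqrt(d_head) attention scaling
    num_kv_heads=num_heads,  # BERT uses plain multi-head attention
    layer_id=0,
    attn_type=AttentionType.ENCODER_ONLY,  # bidirectional self-attention
)
```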
@@ -64,6 +78,7 @@
             self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
         if self.quant_method is not None:
             self.quant_method.create_weights(self)
+        self.attn_type = attn_type

     def forward(
         self,
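
Storing the flag on the layer lets attention backends branch on it later. A minimal sketch, assuming a hypothetical helper `causal_for_layer` that is not in this commit; the only fact taken from the diff is that the layer now carries `attn_type`:

```python
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention


def causal_for_layer(layer: RadixAttention) -> bool:
    # Decoder layers mask future positions; encoder-only layers
    # (BERT-style) attend over the whole sequence in both directions.
    return layer.attn_type == AttentionType.DECODER
```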