Feat/support encoder model (like bert) (#4887)
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# ==============================================================================
|
||||
"""Radix attention."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class AttentionType(Enum):
    """
    Attention type.

    Use string to be compatible with `torch.compile`.
    """

    # Standard decoder self-attention over the previous layer's Q/K/V.
    DECODER = "decoder"
    # Encoder-only (bidirectional) attention, e.g. for BERT-style models.
    ENCODER_ONLY = "encoder_only"
|
||||
|
||||
|
||||
class RadixAttention(nn.Module):
|
||||
"""
|
||||
The attention layer implementation.
|
||||
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
|
||||
sliding_window_size: int = -1,
|
||||
is_cross_attention: bool = False,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
attn_type=AttentionType.DECODER,
|
||||
prefix: str = "",
|
||||
use_irope: bool = False,
|
||||
):
|
||||
@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
if self.quant_method is not None:
|
||||
self.quant_method.create_weights(self)
|
||||
self.attn_type = attn_type
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user