Feat/support encoder model (like bert) (#4887)
@@ -6,6 +6,7 @@ import torch
 from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
         o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self._run_sdpa_forward_extend(
             q_,
             o_,
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
             forward_batch.extend_seq_lens,
             scaling=layer.scaling,
             enable_gqa=use_gqa,
-            causal=not layer.is_cross_attention,
+            causal=causal,
         )
         return o

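For illustration (not part of the patch): the behavior the new causal flag selects can be sketched directly with torch's scaled_dot_product_attention. The helper name and tensor shapes below are assumptions for illustration only.

# Illustrative sketch: how the causal flag above is expected to map onto torch SDPA
# for decoder, cross-attention, and encoder-only (e.g. BERT) layers.
import torch
from torch.nn.functional import scaled_dot_product_attention

def sdpa_for_layer(q, k, v, is_cross_attention=False, encoder_only=False):
    # Decoder self-attention keeps the causal mask; cross-attention and
    # encoder-only attention see the whole sequence (bidirectional).
    causal = not (is_cross_attention or encoder_only)
    return scaled_dot_product_attention(q, k, v, is_causal=causal)

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim), assumed shapes
decoder_out = sdpa_for_layer(q, k, v)                     # causal
encoder_out = sdpa_for_layer(q, k, v, encoder_only=True)  # bidirectional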
@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var, get_device_core_count

@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

+        causal = True
+        if layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -539,6 +544,7 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.kv_indptr,
             self.forward_metadata.kv_indices,
             self.forward_metadata.custom_mask,
+            causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
 ):
@@ -129,6 +130,7 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
+
         offs_kv_loc = tl.load(
             kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
@@ -196,7 +198,11 @@ def _fwd_kernel(

     # stage 2: compute the triangle part

-    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    cur_block_m_end = (
+        cur_seq_len_extend
+        if not IS_CAUSAL
+        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    )
     for start_n in range(0, cur_block_m_end, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_block_m_end
@@ -243,12 +249,15 @@ def _fwd_kernel(
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(custom_mask, qk, float("-inf"))
-        else:
+        elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
             )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(mask_causual, qk, float("-inf"))
+        else:
+            mask_non_causal = mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_non_causal, qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -299,6 +308,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
+    is_causal,
     mask_indptr,
     max_len_extend,
     sm_scale=None,
@@ -411,6 +421,7 @@ def extend_attention_fwd(
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
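For illustration (not part of the patch): the three masking branches the kernel now distinguishes (custom mask, causal, and the new non-causal path used for encoder-only extend) can be written in eager PyTorch as below. Names, offsets, and shapes are illustrative assumptions, not the kernel's own.

# Reference sketch of the masking logic in the kernel above (eager PyTorch).
import torch

def mask_scores(qk, mask_m, mask_n, custom_mask=None, is_causal=True,
                m_offset=0, n_offset=0):
    # qk: (M, N) scores of one query block against one key block.
    valid = mask_m[:, None] & mask_n[None, :]
    if custom_mask is not None:          # USE_CUSTOM_MASK branch
        keep = custom_mask & valid
    elif is_causal:                      # IS_CAUSAL branch: lower-triangular mask
        m_idx = m_offset + torch.arange(qk.shape[0])[:, None]
        n_idx = n_offset + torch.arange(qk.shape[1])[None, :]
        keep = (m_idx >= n_idx) & valid
    else:                                # non-causal branch (encoder-only)
        keep = valid
    return qk.masked_fill(~keep, float("-inf"))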
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Radix attention."""

+from enum import Enum
 from typing import Optional

 from torch import nn
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


+class AttentionType(Enum):
+    """
+    Attention type.
+    Use string to be compatible with `torch.compile`.
+    """
+
+    # Decoder attention between previous layer Q/K/V
+    DECODER = "decoder"
+    # Encoder attention between previous layer Q/K/V
+    ENCODER_ONLY = "encoder_only"
+
+
 class RadixAttention(nn.Module):
     """
     The attention layer implementation.
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
+        attn_type=AttentionType.DECODER,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
             self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
         if self.quant_method is not None:
             self.quant_method.create_weights(self)
+        self.attn_type = attn_type

     def forward(
         self,
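For illustration (not part of the patch): with the new parameter, a decoder layer keeps the default while an encoder-only layer opts in explicitly. The keyword names mirror the constructor above; the argument values are illustrative.

# Construction sketch (illustrative values).
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention

decoder_attn = RadixAttention(
    num_heads=16, head_dim=64, scaling=64**-0.5, num_kv_heads=16, layer_id=0,
)  # attn_type defaults to AttentionType.DECODER, so existing models are unaffected

encoder_attn = RadixAttention(
    num_heads=16, head_dim=64, scaling=64**-0.5, num_kv_heads=16, layer_id=0,
    attn_type=AttentionType.ENCODER_ONLY,  # backends then skip the causal mask
)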
python/sglang/srt/models/bert.py (new file, 398 lines)
@@ -0,0 +1,398 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Iterable, Optional, Set, Tuple

import torch
from torch import nn

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader

BertConfig = None


class BertEmbedding(nn.Module):

    def __init__(self, config: BertConfig):

        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, config.hidden_size
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)),
        )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError(
                "Only 'absolute' position_embedding_type" + " is supported"
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        input_shape = input_ids.size()

        # Input embeddings.
        inputs_embeds = self.word_embeddings(input_ids)

        # Position embeddings.
        position_embeddings = self.position_embeddings(position_ids)

        token_type_ids = torch.zeros(
            input_shape, dtype=torch.long, device=inputs_embeds.device
        )

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


class BertEncoder(nn.Module):

    def __init__(
        self,
        config: BertConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.layer = nn.ModuleList(
            [
                BertLayer(
                    config=config,
                    layer_id=layer_idx,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layer.{layer_idx}",
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        for layer in self.layer:
            hidden_states = layer(hidden_states, forward_batch)
        return hidden_states


class BertLayer(nn.Module):

    def __init__(
        self,
        config: BertConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.attention = BertAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            layer_id=layer_id,
            layer_norm_eps=config.layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )

        self.intermediate = BertIntermediate(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate",
        )

        self.output = BertOutput(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_norm_eps=config.layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
        attn_output = self.attention(hidden_states, forward_batch)
        intermediate_output = self.intermediate(attn_output)
        output = self.output(intermediate_output, attn_output)
        return output


class BertAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.self_attn = BertSelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

        self.output = BertSelfOutput(
            hidden_size=hidden_size,
            layer_norm_eps=layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        self_output = self.self_attn(hidden_states, forward_batch)
        return self.output(self_output, hidden_states)


class BertSelfAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.attn = RadixAttention(
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            scaling=self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_ONLY,
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        output = self.attn(q, k, v, forward_batch)
        return output


class BertSelfOutput(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        layer_norm_eps: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        hidden_states, _ = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertIntermediate(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        layer_norm_eps: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.dense = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        hidden_states, _ = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertModel(nn.Module):

    def __init__(
        self,
        *,
        config: BertConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embeddings = BertEmbedding(config)
        self.encoder = BertEncoder(
            config=config, quant_config=quant_config, prefix=f"encoder"
        )
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        # self.pooler = BertPooler(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        assert get_embedding == True
        # Your tokenized IDs

        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=positions,
        )

        hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
        return self.pooler(hidden_states, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            name = name.replace("self", "self_attn")
            if "pooler" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:

                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


class Contriever(BertModel):
    pass


EntryClass = [BertModel, Contriever]
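For illustration (not part of the patch): BertModel above serves embeddings only (get_embedding must be True) and pools with Pooler(pooling_type=PoolingType.LAST, normalize=True). The HuggingFace sketch below approximates that pooling for comparison; the checkpoint name is a placeholder assumption, not something this PR prescribes.

# Reference sketch only: last-token pooling + L2 normalization on a BERT-family encoder.
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "BAAI/bge-small-en"  # placeholder BERT-family checkpoint
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

inputs = tok("the quick brown fox", return_tensors="pt")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state           # (1, seq_len, hidden_size)

embedding = torch.nn.functional.normalize(hidden[:, -1, :], dim=-1)  # LAST pooling + normalize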
@@ -51,6 +51,8 @@ NUM_TOP_LOGPROBS = 5
 def get_dtype_str(torch_dtype):
     if torch_dtype is torch.float16:
         return "float16"
+    if torch_dtype is torch.float32:
+        return "float32"
     else:
         raise NotImplementedError()

@@ -447,6 +449,7 @@ class SRTRunner:
         port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
+        attention_backend: Optional[str] = None,
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
@@ -487,6 +490,7 @@ class SRTRunner:
             lora_paths=lora_paths,
             max_loras_per_batch=max_loras_per_batch,
             lora_backend=lora_backend,
+            attention_backend=attention_backend,
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
             chunked_prefill_size=chunked_prefill_size,