Llama3.2 vision model support (#1551)
This commit is contained in:
@@ -8,16 +8,12 @@ version = "0.3.4"
|
||||
description = "SGLang is yet another fast serving framework for large language models and vision language models."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
license = {file = "LICENSE"}
|
||||
license = { file = "LICENSE" }
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
]
|
||||
dependencies = [
|
||||
"requests",
|
||||
"tqdm",
|
||||
"numpy",
|
||||
]
|
||||
dependencies = ["requests", "tqdm", "numpy"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
|
||||
@@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"]
|
||||
openai = ["openai>=1.0", "tiktoken"]
|
||||
anthropic = ["anthropic>=0.20.0"]
|
||||
litellm = ["litellm>=1.0.0"]
|
||||
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
|
||||
test = [
|
||||
"jsonlines",
|
||||
"matplotlib",
|
||||
"pandas",
|
||||
"sentence_transformers",
|
||||
"accelerate",
|
||||
"peft",
|
||||
]
|
||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||
dev = ["sglang[all]", "sglang[test]"]
|
||||
@@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
|
||||
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
|
||||
exclude = [
|
||||
"assets*",
|
||||
"benchmark*",
|
||||
"docs*",
|
||||
"dist*",
|
||||
"playground*",
|
||||
"scripts*",
|
||||
"tests*",
|
||||
]
|
||||
|
||||
[tool.wheel]
|
||||
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
|
||||
exclude = [
|
||||
"assets*",
|
||||
"benchmark*",
|
||||
"docs*",
|
||||
"dist*",
|
||||
"playground*",
|
||||
"scripts*",
|
||||
"tests*",
|
||||
]
|
||||
|
||||
@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
|
||||
req_to_token_pool=model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||
tree_cache=None,
|
||||
model_config=model_runner.model_config,
|
||||
)
|
||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
||||
batch.prepare_for_extend()
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
|
||||
@@ -229,6 +229,7 @@ register_chat_template(
|
||||
),
|
||||
},
|
||||
stop_str=("<|eot_id|>",),
|
||||
image_token="<|image|>",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -89,6 +89,8 @@ class ModelConfig:
|
||||
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
|
||||
self.vocab_size = self.hf_text_config.vocab_size
|
||||
|
||||
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
||||
def get_total_num_kv_heads(self) -> int:
|
||||
"""Returns the total number of KV heads."""
|
||||
|
||||
@@ -509,6 +509,19 @@ register_conv_template(
|
||||
)
|
||||
)
|
||||
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="llama_3_vision",
|
||||
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
|
||||
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
||||
roles=("user", "assistant"),
|
||||
sep_style=SeparatorStyle.LLAMA3,
|
||||
sep="",
|
||||
stop_str=["<|end_of_text|>", "<|eot_id|>"],
|
||||
image_token="<|image|>",
|
||||
)
|
||||
)
|
||||
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="llava_llama_3",
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
@@ -19,7 +21,11 @@ class AttentionBackend(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
||||
raise NotImplementedError()
|
||||
@@ -30,6 +36,7 @@ class AttentionBackend(ABC):
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||
raise NotImplementedError()
|
||||
@@ -43,7 +50,7 @@ class AttentionBackend(ABC):
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer: nn.Module,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
"""Run forward on an attention layer."""
|
||||
@@ -57,7 +64,7 @@ class AttentionBackend(ABC):
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer: nn.Module,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
"""Run a forward for decode."""
|
||||
@@ -68,7 +75,7 @@ class AttentionBackend(ABC):
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer: nn.Module,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
"""Run a forward for extend."""
|
||||
|
||||
@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
|
||||
@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens=None,
|
||||
):
|
||||
# NOTE: encoder_lens expected to be zeros or None
|
||||
self.forward_metadata = (
|
||||
self.cuda_graph_start_loc,
|
||||
self.cuda_graph_attn_logits,
|
||||
@@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens=None,
|
||||
):
|
||||
# NOTE: encoder_lens expected to be zeros or None
|
||||
self.cuda_graph_start_loc.zero_()
|
||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1
|
||||
|
||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
def forward_extend(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
# TODO: reuse the buffer across layers
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
@@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
|
||||
layer, forward_batch.out_cache_loc, k, v, k_label
|
||||
)
|
||||
|
||||
(
|
||||
@@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
||||
)
|
||||
return o
|
||||
|
||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
def forward_decode(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
@@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
|
||||
layer, forward_batch.out_cache_loc, k, v, k_label
|
||||
)
|
||||
|
||||
# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
|
||||
|
||||
@@ -11,7 +11,6 @@ from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
if is_flashinfer_available():
|
||||
@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
assert not (
|
||||
model_runner.sliding_window_size is not None
|
||||
and model_runner.has_cross_attention
|
||||
and model_runner.model_config.is_encoder_decoder
|
||||
), "Sliding window and cross attention are not supported together"
|
||||
|
||||
if model_runner.sliding_window_size is not None:
|
||||
self.num_wrappers = 2
|
||||
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
|
||||
elif model_runner.has_cross_attention:
|
||||
elif model_runner.model_config.is_encoder_decoder:
|
||||
self.num_wrappers = 2
|
||||
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
|
||||
else:
|
||||
@@ -128,6 +128,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
decode_wrappers=None,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
)
|
||||
self.forward_metadata = (self.decode_wrappers,)
|
||||
else:
|
||||
@@ -144,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
prefix_lens,
|
||||
use_ragged,
|
||||
use_ragged=use_ragged,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
)
|
||||
|
||||
self.forward_metadata = (
|
||||
use_ragged,
|
||||
extend_no_prefix,
|
||||
)
|
||||
self.forward_metadata = (use_ragged, extend_no_prefix)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
cuda_graph_kv_indices = torch.zeros(
|
||||
@@ -163,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
]
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: torch.Tensor = None,
|
||||
):
|
||||
decode_wrappers = []
|
||||
for i in range(self.num_wrappers):
|
||||
@@ -181,7 +185,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
self.indices_updater_decode.update(
|
||||
req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
decode_wrappers=decode_wrappers,
|
||||
encoder_lens=encoder_lens,
|
||||
)
|
||||
self.cuda_graph_metadata[bs] = decode_wrappers
|
||||
self.forward_metadata = (decode_wrappers,)
|
||||
@@ -192,34 +200,42 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: torch.Tensor = None,
|
||||
):
|
||||
self.indices_updater_decode.update(
|
||||
req_pool_indices[:bs],
|
||||
seq_lens[:bs],
|
||||
seq_lens_sum,
|
||||
self.cuda_graph_metadata[bs],
|
||||
decode_wrappers=self.cuda_graph_metadata[bs],
|
||||
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
||||
)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 0
|
||||
|
||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
def forward_extend(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
prefill_wrapper_paged = self.prefill_wrappers_paged[
|
||||
self._get_wrapper_idx(layer)
|
||||
]
|
||||
|
||||
use_ragged, extend_no_prefix = self.forward_metadata
|
||||
cache_loc = (
|
||||
forward_batch.out_cache_loc
|
||||
if not layer.is_cross_attention
|
||||
else forward_batch.encoder_out_cache_loc
|
||||
)
|
||||
|
||||
if not use_ragged:
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||
|
||||
o = prefill_wrapper_paged.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
causal=True,
|
||||
causal=not layer.is_cross_attention,
|
||||
sm_scale=layer.scaling,
|
||||
window_left=layer.sliding_window_size,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
@@ -247,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
o, _ = merge_state(o1, s1, o2, s2)
|
||||
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
def forward_decode(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
|
||||
cache_loc = (
|
||||
forward_batch.out_cache_loc
|
||||
if not layer.is_cross_attention
|
||||
else forward_batch.encoder_out_cache_loc
|
||||
)
|
||||
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||
|
||||
o = decode_wrapper.forward(
|
||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
@@ -271,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
def _get_wrapper_idx(self, layer: nn.Module):
|
||||
def _get_wrapper_idx(self, layer: RadixAttention):
|
||||
if self.num_wrappers == 1:
|
||||
return 0
|
||||
|
||||
@@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
|
||||
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
|
||||
self.sliding_window_size = model_runner.sliding_window_size
|
||||
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
# Buffers and wrappers
|
||||
self.kv_indptr = attn_backend.kv_indptr
|
||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||
@@ -305,20 +326,27 @@ class FlashInferIndicesUpdaterDecode:
|
||||
self.decode_wrappers = attn_backend.decode_wrappers
|
||||
|
||||
# Dispatch
|
||||
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||
self.update = self.update_sliding_window
|
||||
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||
self.update = self.update_cross_attention
|
||||
else:
|
||||
assert attn_backend.num_wrappers == 1
|
||||
assert self.attn_backend.num_wrappers == 1
|
||||
self.update = self.update_single_wrapper
|
||||
|
||||
def update(
|
||||
self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
|
||||
):
|
||||
# Keep the signature for type checking, will be initialized during runtime
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_single_wrapper(
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers=None,
|
||||
encoder_lens=None,
|
||||
):
|
||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||
self.call_begin_forward(
|
||||
@@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers=None,
|
||||
encoder_lens=None,
|
||||
):
|
||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||
|
||||
@@ -363,8 +392,35 @@ class FlashInferIndicesUpdaterDecode:
|
||||
kv_start_idx_tmp,
|
||||
)
|
||||
|
||||
def update_cross_attention(self):
|
||||
raise NotImplementedError()
|
||||
def update_cross_attention(
|
||||
self,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
decode_wrappers=None,
|
||||
encoder_lens=None,
|
||||
):
|
||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
# Normal attention
|
||||
paged_kernel_lens = seq_lens
|
||||
kv_start_idx = encoder_lens
|
||||
else:
|
||||
# Cross attention
|
||||
paged_kernel_lens = encoder_lens
|
||||
kv_start_idx = torch.zeros_like(encoder_lens)
|
||||
seq_lens_sum = encoder_lens.sum().item()
|
||||
|
||||
self.call_begin_forward(
|
||||
decode_wrappers[wrapper_id],
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
seq_lens_sum,
|
||||
self.kv_indptr[wrapper_id],
|
||||
kv_start_idx,
|
||||
)
|
||||
|
||||
def call_begin_forward(
|
||||
self,
|
||||
@@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
|
||||
self.sliding_window_size = model_runner.sliding_window_size
|
||||
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
# Buffers and wrappers
|
||||
self.kv_indptr = attn_backend.kv_indptr
|
||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||
@@ -430,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.wrappers_paged = attn_backend.prefill_wrappers_paged
|
||||
|
||||
# Dispatch
|
||||
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||
self.update = self.update_sliding_window
|
||||
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||
self.update = self.update_cross_attention
|
||||
else:
|
||||
assert attn_backend.num_wrappers == 1
|
||||
assert self.attn_backend.num_wrappers == 1
|
||||
self.update = self.update_single_wrapper
|
||||
|
||||
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
|
||||
# Keep the signature for type checking, will be initialized during runtime
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_single_wrapper(
|
||||
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
|
||||
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
|
||||
):
|
||||
if use_ragged:
|
||||
paged_kernel_lens = prefix_lens
|
||||
@@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
)
|
||||
|
||||
def update_sliding_window(
|
||||
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
|
||||
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -487,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
use_ragged,
|
||||
)
|
||||
|
||||
def update_cross_attention(self):
|
||||
raise NotImplementedError()
|
||||
def update_cross_attention(
|
||||
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
# normal attention
|
||||
paged_kernel_lens = seq_lens
|
||||
kv_start_idx = encoder_lens
|
||||
else:
|
||||
# cross attention
|
||||
paged_kernel_lens = encoder_lens
|
||||
kv_start_idx = torch.zeros_like(encoder_lens)
|
||||
|
||||
self.call_begin_forward(
|
||||
self.wrapper_ragged,
|
||||
self.wrappers_paged[wrapper_id],
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
kv_start_idx,
|
||||
self.kv_indptr[wrapper_id],
|
||||
self.qo_indptr[wrapper_id],
|
||||
use_ragged,
|
||||
)
|
||||
|
||||
def call_begin_forward(
|
||||
self,
|
||||
|
||||
@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
|
||||
@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens=None,
|
||||
):
|
||||
# NOTE: encoder_lens expected to be zeros or None
|
||||
self.forward_metadata = (
|
||||
self.cuda_graph_start_loc,
|
||||
self.cuda_graph_attn_logits,
|
||||
@@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend):
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens=None,
|
||||
):
|
||||
# NOTE: encoder_lens expected to be zeros or None
|
||||
self.cuda_graph_start_loc.zero_()
|
||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1
|
||||
|
||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
def forward_extend(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
# TODO: reuse the buffer across layers
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
@@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
o = torch.empty_like(q)
|
||||
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
|
||||
@@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
|
||||
)
|
||||
return o
|
||||
|
||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||
def forward_decode(
|
||||
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||
):
|
||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
@@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
|
||||
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
||||
layer, forward_batch.out_cache_loc, k, v
|
||||
)
|
||||
|
||||
self.decode_attention_fwd(
|
||||
|
||||
@@ -33,20 +33,9 @@ def init_global_processor(server_args: ServerArgs):
|
||||
|
||||
|
||||
class BaseImageProcessor(ABC):
|
||||
@abstractmethod
|
||||
async def process_images_async(self, image_data, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class DummyImageProcessor(BaseImageProcessor):
|
||||
async def process_images_async(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _image_processor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
self.hf_config = hf_config
|
||||
self._image_processor = _image_processor
|
||||
self._processor = _processor
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
mp_context=mp.get_context("fork"),
|
||||
@@ -54,6 +43,23 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def process_images_async(self, image_data, input_text, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class DummyImageProcessor(BaseImageProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def process_images_async(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(
|
||||
image_data: Union[str, bytes],
|
||||
@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
)
|
||||
|
||||
async def process_images_async(
|
||||
self, image_data: List[Union[str, bytes]], request_obj
|
||||
self, image_data: List[Union[str, bytes]], input_text, request_obj
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
@@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
}
|
||||
|
||||
|
||||
class MllamaImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(images, input_text):
|
||||
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
|
||||
return global_processor(images, input_text, return_tensors="pt")
|
||||
|
||||
async def _process_single_image(self, images, input_text):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
MllamaImageProcessor._process_single_image_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._processor(images, input_text, return_tensors="pt")
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_images_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
if len(image_data) > 0:
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
else:
|
||||
images = load_image(image_data[0])[0]
|
||||
|
||||
image_inputs = await self._process_single_image(images, input_text)
|
||||
image_inputs["image_hashes"] = [hash(str(image_data))]
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
|
||||
return image_inputs
|
||||
|
||||
|
||||
class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _image_processor):
|
||||
self.hf_config = hf_config
|
||||
@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
return self._process_single_image_task(image_data)
|
||||
|
||||
async def process_images_async(
|
||||
self, image_data: List[Union[str, bytes]], request_obj
|
||||
self, image_data: List[Union[str, bytes]], input_text, request_obj
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
@@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
|
||||
|
||||
def get_image_processor(
|
||||
hf_config, server_args: ServerArgs, _image_processor
|
||||
hf_config, server_args: ServerArgs, processor
|
||||
) -> BaseImageProcessor:
|
||||
if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
|
||||
return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
|
||||
if "MllamaForConditionalGeneration" in hf_config.architectures:
|
||||
return MllamaImageProcessor(hf_config, server_args, processor)
|
||||
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
|
||||
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
|
||||
else:
|
||||
return LlavaImageProcessor(hf_config, server_args, _image_processor)
|
||||
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
|
||||
|
||||
|
||||
def get_dummy_image_processor():
|
||||
|
||||
@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
@@ -121,11 +122,12 @@ class ImageInputs:
|
||||
"""The image related inputs."""
|
||||
|
||||
pixel_values: torch.Tensor
|
||||
image_hash: int
|
||||
image_hashes: Optional[list] = None
|
||||
image_sizes: Optional[list] = None
|
||||
image_offsets: Optional[list] = None
|
||||
pad_values: Optional[list] = None
|
||||
modalities: Optional[list] = None
|
||||
num_image_tokens: Optional[int] = None
|
||||
|
||||
image_embeds: Optional[List[torch.Tensor]] = None
|
||||
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
||||
@@ -138,19 +140,27 @@ class ImageInputs:
|
||||
# Use image hash as fake token_ids, which is then used for prefix matching
|
||||
ret = ImageInputs(
|
||||
pixel_values=obj["pixel_values"],
|
||||
image_hash=hash(tuple(obj["image_hashes"])),
|
||||
image_grid_thws=obj.get("image_grid_thws"),
|
||||
image_hashes=hash(tuple(obj["image_hashes"])),
|
||||
)
|
||||
image_hash = ret.image_hash
|
||||
image_hash = ret.image_hashes
|
||||
ret.pad_values = [
|
||||
(image_hash) % vocab_size,
|
||||
(image_hash >> 16) % vocab_size,
|
||||
(image_hash >> 32) % vocab_size,
|
||||
(image_hash >> 64) % vocab_size,
|
||||
]
|
||||
ret.image_sizes = obj["image_sizes"]
|
||||
# Only when pixel values is not None we have modalities
|
||||
ret.modalities = obj["modalities"] or ["image"]
|
||||
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
"modalities",
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"image_grid_thws",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if arg in obj:
|
||||
setattr(ret, arg, obj[arg])
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@@ -416,6 +426,10 @@ class ScheduleBatch:
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
token_to_kv_pool: BaseTokenToKVPool = None
|
||||
tree_cache: BasePrefixCache = None
|
||||
|
||||
# For utility
|
||||
model_config: ModelConfig = None
|
||||
|
||||
forward_mode: ForwardMode = None
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
|
||||
@@ -440,6 +454,12 @@ class ScheduleBatch:
|
||||
extend_num_tokens: int = None
|
||||
decoding_reqs: List[Req] = None
|
||||
|
||||
# For encoder-decoder
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
encoder_lens: Optional[torch.Tensor] = None
|
||||
encoder_lens_cpu: Optional[List[int]] = None
|
||||
encoder_out_cache_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# Stream
|
||||
has_stream: bool = False
|
||||
|
||||
@@ -450,12 +470,20 @@ class ScheduleBatch:
|
||||
device: str = "cuda"
|
||||
|
||||
@classmethod
|
||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||
def init_new(
|
||||
cls,
|
||||
reqs,
|
||||
req_to_token_pool,
|
||||
token_to_kv_pool,
|
||||
tree_cache,
|
||||
model_config,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
tree_cache=tree_cache,
|
||||
model_config=model_config,
|
||||
return_logprob=any(req.return_logprob for req in reqs),
|
||||
has_stream=any(req.stream for req in reqs),
|
||||
has_regex=any(req.regex_fsm for req in reqs),
|
||||
@@ -493,7 +521,78 @@ class ScheduleBatch:
|
||||
|
||||
return out_cache_loc
|
||||
|
||||
def prepare_for_extend(self, vocab_size: int):
|
||||
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||
self.encoder_lens_cpu = []
|
||||
self.encoder_cached = []
|
||||
|
||||
for req in self.reqs:
|
||||
im = req.image_inputs
|
||||
if im is None or im.num_image_tokens is None:
|
||||
# No image input
|
||||
self.encoder_lens_cpu.append(0)
|
||||
self.encoder_cached.append(True)
|
||||
else:
|
||||
self.encoder_lens_cpu.append(im.num_image_tokens)
|
||||
self.encoder_cached.append(
|
||||
self.forward_mode.is_decode()
|
||||
or len(req.prefix_indices) >= im.num_image_tokens
|
||||
)
|
||||
|
||||
self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
# Strip encoder infos
|
||||
pt = 0
|
||||
decoder_out_cache_loc = []
|
||||
encoder_out_cache_loc = []
|
||||
for i, req in enumerate(self.reqs):
|
||||
encoder_len = self.encoder_lens_cpu[i]
|
||||
seq_lens[i] -= encoder_len
|
||||
|
||||
if len(req.prefix_indices) < encoder_len:
|
||||
# NOTE: the encoder part should considered as a whole
|
||||
assert len(req.prefix_indices) == 0
|
||||
input_ids[i] = input_ids[i][encoder_len:]
|
||||
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
|
||||
decoder_out_cache_loc.append(
|
||||
self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
|
||||
)
|
||||
self.extend_lens[i] -= encoder_len
|
||||
self.extend_num_tokens -= encoder_len
|
||||
else:
|
||||
decoder_out_cache_loc.append(
|
||||
self.out_cache_loc[pt : pt + req.extend_input_len]
|
||||
)
|
||||
self.prefix_lens[i] -= encoder_len
|
||||
|
||||
pt += req.extend_input_len
|
||||
|
||||
# Reassign
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
if not decoder_out_cache_loc:
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
else:
|
||||
self.out_cache_loc = torch.cat(decoder_out_cache_loc)
|
||||
|
||||
if not encoder_out_cache_loc:
|
||||
self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
else:
|
||||
self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
|
||||
|
||||
assert len(self.out_cache_loc) == self.extend_num_tokens
|
||||
|
||||
def prepare_for_extend(self):
|
||||
self.forward_mode = ForwardMode.EXTEND
|
||||
|
||||
bs = len(self.reqs)
|
||||
@@ -561,8 +660,13 @@ class ScheduleBatch:
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.prepare_encoder_info_extend(input_ids, seq_lens)
|
||||
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self, vocab_size, global_server_args_dict["disable_penalizer"]
|
||||
self,
|
||||
self.model_config.vocab_size,
|
||||
global_server_args_dict["disable_penalizer"],
|
||||
)
|
||||
|
||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||
@@ -752,6 +856,10 @@ class ScheduleBatch:
|
||||
|
||||
return jump_forward_reqs
|
||||
|
||||
def prepare_encoder_info_decode(self):
|
||||
# Reset the encoder cached status
|
||||
self.encoder_cached = [True] * len(self.reqs)
|
||||
|
||||
def prepare_for_decode(self, enable_overlap: bool = False):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
|
||||
@@ -766,16 +874,22 @@ class ScheduleBatch:
|
||||
bs = len(self.reqs)
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
locs = self.encoder_lens + self.seq_lens
|
||||
self.prepare_encoder_info_decode()
|
||||
else:
|
||||
locs = self.seq_lens
|
||||
|
||||
if enable_overlap:
|
||||
# Do not use in-place operations in the overlap mode
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens = self.seq_lens + 1
|
||||
else:
|
||||
# A faster in-place version
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
self.seq_lens_sum += bs
|
||||
@@ -802,6 +916,10 @@ class ScheduleBatch:
|
||||
# No need to filter
|
||||
return
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.encoder_lens = self.encoder_lens[keep_indices]
|
||||
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
||||
|
||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
@@ -828,6 +946,11 @@ class ScheduleBatch:
|
||||
# needs to be called with pre-merged Batch.reqs.
|
||||
self.sampling_info.merge_batch(other.sampling_info)
|
||||
|
||||
# Encoder-decoder infos
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
|
||||
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
|
||||
|
||||
self.req_pool_indices = torch.concat(
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
@@ -850,14 +973,11 @@ class ScheduleBatch:
|
||||
|
||||
def get_model_worker_batch(self):
|
||||
if self.forward_mode.is_decode():
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
|
||||
image_inputs
|
||||
) = None
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
extend_seq_lens = self.extend_lens
|
||||
extend_prefix_lens = self.prefix_lens
|
||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||
image_inputs = [r.image_inputs for r in self.reqs]
|
||||
|
||||
if self.has_regex:
|
||||
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
|
||||
@@ -887,7 +1007,11 @@ class ScheduleBatch:
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
extend_logprob_start_lens=extend_logprob_start_lens,
|
||||
image_inputs=image_inputs,
|
||||
image_inputs=[r.image_inputs for r in self.reqs],
|
||||
encoder_cached=self.encoder_cached,
|
||||
encoder_lens=self.encoder_lens,
|
||||
encoder_lens_cpu=self.encoder_lens_cpu,
|
||||
encoder_out_cache_loc=self.encoder_out_cache_loc,
|
||||
lora_paths=[req.lora_path for req in self.reqs],
|
||||
sampling_info=self.sampling_info,
|
||||
mrope_positions_delta=mrope_positions_delta,
|
||||
@@ -897,6 +1021,7 @@ class ScheduleBatch:
|
||||
# Only contain fields that will be used by process_batch_result
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
model_config=self.model_config,
|
||||
forward_mode=self.forward_mode,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
return_logprob=self.return_logprob,
|
||||
@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]]
|
||||
|
||||
# For encoder-decoder
|
||||
encoder_cached: Optional[List[bool]]
|
||||
encoder_lens: Optional[torch.Tensor]
|
||||
encoder_lens_cpu: Optional[List[int]]
|
||||
encoder_out_cache_loc: Optional[torch.Tensor]
|
||||
|
||||
# For LoRA
|
||||
lora_paths: Optional[List[str]]
|
||||
|
||||
|
||||
@@ -662,8 +662,9 @@ class Scheduler:
|
||||
self.req_to_token_pool,
|
||||
self.token_to_kv_pool,
|
||||
self.tree_cache,
|
||||
self.model_config,
|
||||
)
|
||||
new_batch.prepare_for_extend(self.model_config.vocab_size)
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
# Mixed-style chunked prefill
|
||||
if self.is_mixed_chunk and self.running_batch is not None:
|
||||
|
||||
@@ -122,7 +122,7 @@ class TokenizerManager:
|
||||
|
||||
# We want to parallelize the image pre-processing so we create an executor for it
|
||||
self.image_processor = get_image_processor(
|
||||
self.hf_config, server_args, self.processor.image_processor
|
||||
self.hf_config, server_args, self.processor
|
||||
)
|
||||
else:
|
||||
self.tokenizer = get_tokenizer(
|
||||
@@ -191,8 +191,10 @@ class TokenizerManager:
|
||||
sampling_params = self._get_sampling_params(obj.sampling_params)
|
||||
if self.is_generation:
|
||||
image_inputs = await self.image_processor.process_images_async(
|
||||
obj.image_data, obj
|
||||
obj.image_data, input_text or input_ids, obj
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
return_logprob = obj.return_logprob
|
||||
logprob_start_len = obj.logprob_start_len
|
||||
top_logprobs_num = obj.top_logprobs_num
|
||||
@@ -217,8 +219,10 @@ class TokenizerManager:
|
||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||
if self.is_generation:
|
||||
image_inputs = await self.image_processor.process_images_async(
|
||||
obj.image_data[index], obj
|
||||
obj.image_data[index], input_text or input_ids, obj
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
return_logprob = obj.return_logprob[index]
|
||||
logprob_start_len = obj.logprob_start_len[index]
|
||||
top_logprobs_num = obj.top_logprobs_num[index]
|
||||
@@ -263,8 +267,10 @@ class TokenizerManager:
|
||||
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||
sampling_params.max_new_tokens = 0
|
||||
image_inputs = await self.image_processor.process_images_async(
|
||||
obj.image_data[0], obj
|
||||
obj.image_data[0], input_text or input_ids, obj
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
return_logprob = obj.return_logprob[0]
|
||||
logprob_start_len = obj.logprob_start_len[0]
|
||||
top_logprobs_num = obj.top_logprobs_num[0]
|
||||
|
||||
@@ -26,6 +26,8 @@ from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -41,13 +43,17 @@ class ReqToTokenPool:
|
||||
)
|
||||
self.free_slots = list(range(size))
|
||||
self.write_records = []
|
||||
self.use_records = use_records
|
||||
|
||||
if use_records:
|
||||
# records all write operations
|
||||
if self.use_records:
|
||||
self.write = self.write_with_records
|
||||
else:
|
||||
self.write = self.write_without_records
|
||||
|
||||
def write(self, indices, values):
|
||||
# Keep the signature for type checking, will be initialized during runtime
|
||||
raise NotImplementedError()
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_slots)
|
||||
|
||||
@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
layer: RadixAttention,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
layer: RadixAttention,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
if cache_v.dtype != self.dtype:
|
||||
@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
layer: RadixAttention,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
cache_k = cache_k.to(self.dtype)
|
||||
if self.store_dtype != self.dtype:
|
||||
@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
layer: RadixAttention,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
cache_label: torch.Tensor,
|
||||
):
|
||||
# NOTE(Andy): ignore the dtype check
|
||||
layer_id = layer.layer_id
|
||||
self.k_buffer[layer_id][loc] = cache_k
|
||||
self.v_buffer[layer_id][loc] = cache_v
|
||||
self.label_buffer[layer_id][loc] = cache_label
|
||||
|
||||
@@ -105,6 +105,7 @@ class CudaGraphRunner:
|
||||
self.graph_memory_pool = None
|
||||
self.use_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
|
||||
|
||||
# Batch sizes to capture
|
||||
if self.model_runner.server_args.disable_cuda_graph_padding:
|
||||
@@ -132,6 +133,9 @@ class CudaGraphRunner:
|
||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||
)
|
||||
|
||||
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
||||
self.encoder_len_fill_value = 0
|
||||
|
||||
if self.use_torch_compile:
|
||||
set_torch_compile_config()
|
||||
|
||||
@@ -144,9 +148,18 @@ class CudaGraphRunner:
|
||||
)
|
||||
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
||||
self.encoder_lens = torch.full(
|
||||
(self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
self.encoder_lens = None
|
||||
|
||||
# Capture
|
||||
try:
|
||||
self.capture()
|
||||
with self.model_capture_mode():
|
||||
self.capture()
|
||||
except RuntimeError as e:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed: {e}\n"
|
||||
@@ -157,11 +170,32 @@ class CudaGraphRunner:
|
||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||
)
|
||||
|
||||
def can_run(self, batch_size: int):
|
||||
if self.disable_padding:
|
||||
return batch_size in self.graphs
|
||||
else:
|
||||
return batch_size <= self.max_bs
|
||||
@contextmanager
|
||||
def model_capture_mode(self):
|
||||
if hasattr(self.model_runner.model, "capture_mode"):
|
||||
self.model_runner.model.capture_mode = True
|
||||
|
||||
yield
|
||||
|
||||
if hasattr(self.model_runner.model, "capture_mode"):
|
||||
self.model_runner.model.capture_mode = False
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
is_bs_supported = (
|
||||
forward_batch.batch_size in self.graphs
|
||||
if self.disable_padding
|
||||
else forward_batch.batch_size <= self.max_bs
|
||||
)
|
||||
|
||||
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
|
||||
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
|
||||
# because the full_text_row_masked_out_mask tensor will always be ones
|
||||
is_encoder_lens_supported = (
|
||||
torch.all(forward_batch.encoder_lens > 0)
|
||||
if self.is_encoder_decoder
|
||||
else True
|
||||
)
|
||||
return is_bs_supported and is_encoder_lens_supported
|
||||
|
||||
def capture(self):
|
||||
with graph_capture() as graph_capture_context:
|
||||
@@ -188,11 +222,19 @@ class CudaGraphRunner:
|
||||
req_pool_indices = self.req_pool_indices[:bs]
|
||||
seq_lens = self.seq_lens[:bs]
|
||||
out_cache_loc = self.out_cache_loc[:bs]
|
||||
if self.is_encoder_decoder:
|
||||
encoder_lens = self.encoder_lens[:bs]
|
||||
else:
|
||||
encoder_lens = None
|
||||
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||
bs, req_pool_indices, seq_lens
|
||||
bs,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
encoder_lens,
|
||||
)
|
||||
|
||||
# Run and capture
|
||||
@@ -208,6 +250,7 @@ class CudaGraphRunner:
|
||||
attn_backend=self.model_runner.attn_backend,
|
||||
out_cache_loc=out_cache_loc,
|
||||
seq_lens_sum=seq_lens_sum,
|
||||
encoder_lens=encoder_lens,
|
||||
return_logprob=False,
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
||||
@@ -251,6 +294,8 @@ class CudaGraphRunner:
|
||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
|
||||
if self.is_encoder_decoder:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
@@ -258,6 +303,7 @@ class CudaGraphRunner:
|
||||
self.req_pool_indices,
|
||||
self.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
self.encoder_lens,
|
||||
)
|
||||
|
||||
# Replay
|
||||
|
||||
@@ -108,6 +108,12 @@ class ForwardBatch:
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]] = None
|
||||
|
||||
# Encoder-decoder
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
encoder_lens: Optional[torch.Tensor] = None
|
||||
encoder_lens_cpu: Optional[List[int]] = None
|
||||
encoder_out_cache_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# For LoRA
|
||||
lora_paths: Optional[List[str]] = None
|
||||
|
||||
@@ -194,6 +200,11 @@ class ForwardBatch:
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
image_inputs=batch.image_inputs,
|
||||
encoder_cached=batch.encoder_cached,
|
||||
encoder_lens=batch.encoder_lens,
|
||||
encoder_lens_cpu=batch.encoder_lens_cpu,
|
||||
encoder_out_cache_loc=batch.encoder_out_cache_loc,
|
||||
seq_lens_sum=batch.seq_lens_sum,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
@@ -212,11 +223,11 @@ class ForwardBatch:
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
ret.image_inputs = batch.image_inputs
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
ret.extend_seq_lens = torch.tensor(
|
||||
batch.extend_seq_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
@@ -270,7 +270,6 @@ class ModelRunner:
|
||||
if hasattr(self.model, "get_attention_sliding_window_size")
|
||||
else None
|
||||
)
|
||||
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
|
||||
self.is_generation = is_generation_model(
|
||||
self.model_config.hf_config.architectures, self.server_args.is_embedding
|
||||
)
|
||||
@@ -510,7 +509,7 @@ class ModelRunner:
|
||||
"Window attention is not supported in the triton attention backend. "
|
||||
"Please use `--attention-backend flashinfer`."
|
||||
)
|
||||
assert not self.has_cross_attention, (
|
||||
assert not self.model_config.is_encoder_decoder, (
|
||||
"Cross attention is not supported in the triton attention backend. "
|
||||
"Please use `--attention-backend flashinfer`."
|
||||
)
|
||||
@@ -558,9 +557,7 @@ class ModelRunner:
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
|
||||
def forward_decode(self, forward_batch: ForwardBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
|
||||
forward_batch.batch_size
|
||||
):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
|
||||
|
||||
1004
python/sglang/srt/models/mllama.py
Normal file
1004
python/sglang/srt/models/mllama.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -605,7 +605,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
]
|
||||
|
||||
positions = forward_batch.mrope_positions
|
||||
if image_inputs is None or len(image_inputs) == 0:
|
||||
if (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or image_inputs is None
|
||||
or len(image_inputs) == 0
|
||||
):
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
else:
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
|
||||
@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
|
||||
or "LlavaQwenForCausalLM" in model_architectures
|
||||
or "LlavaMistralForCausalLM" in model_architectures
|
||||
or "LlavaVidForCausalLM" in model_architectures
|
||||
or "MllamaForConditionalGeneration" in model_architectures
|
||||
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||
):
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user