Llama3.2 vision model support (#1551)

This commit is contained in:
Liangsheng Yin
2024-10-21 15:01:21 -07:00
committed by GitHub
parent 00611286a1
commit 94cde10920
21 changed files with 1562 additions and 122 deletions

View File

@@ -8,16 +8,12 @@ version = "0.3.4"
description = "SGLang is yet another fast serving framework for large language models and vision language models." description = "SGLang is yet another fast serving framework for large language models and vision language models."
readme = "README.md" readme = "README.md"
requires-python = ">=3.8" requires-python = ">=3.8"
license = {file = "LICENSE"} license = { file = "LICENSE" }
classifiers = [ classifiers = [
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License", "License :: OSI Approved :: Apache Software License",
] ]
dependencies = [ dependencies = ["requests", "tqdm", "numpy"]
"requests",
"tqdm",
"numpy",
]
[project.optional-dependencies] [project.optional-dependencies]
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
@@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"]
openai = ["openai>=1.0", "tiktoken"] openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"] anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"] litellm = ["litellm>=1.0.0"]
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"] test = [
"jsonlines",
"matplotlib",
"pandas",
"sentence_transformers",
"accelerate",
"peft",
]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
dev = ["sglang[all]", "sglang[test]"] dev = ["sglang[all]", "sglang[test]"]
@@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues" "Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]
[tool.wheel] [tool.wheel]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]

View File

@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
req_to_token_pool=model_runner.req_to_token_pool, req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool, token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None, tree_cache=None,
model_config=model_runner.model_config,
) )
batch.prepare_for_extend(model_runner.model_config.vocab_size) batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch) logits_output = model_runner.forward(forward_batch)

View File

@@ -229,6 +229,7 @@ register_chat_template(
), ),
}, },
stop_str=("<|eot_id|>",), stop_str=("<|eot_id|>",),
image_token="<|image|>",
) )
) )

View File

@@ -89,6 +89,8 @@ class ModelConfig:
self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size self.vocab_size = self.hf_text_config.vocab_size
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int: def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads.""" """Returns the total number of KV heads."""

View File

@@ -509,6 +509,19 @@ register_conv_template(
) )
) )
# Register the "llama_3_vision" conversation template: LLAMA3 separator style
# with an empty separator, stop strings <|end_of_text|> and <|eot_id|>, and
# image_token set to <|image|>.
register_conv_template(
    Conversation(
        name="llama_3_vision",
        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
        roles=("user", "assistant"),
        sep_style=SeparatorStyle.LLAMA3,
        sep="",
        stop_str=["<|end_of_text|>", "<|eot_id|>"],
        image_token="<|image|>",
    )
)
register_conv_template( register_conv_template(
Conversation( Conversation(
name="llava_llama_3", name="llava_llama_3",

View File

@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional
import torch import torch
from torch import nn from torch import nn
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -19,7 +21,11 @@ class AttentionBackend(ABC):
raise NotImplementedError() raise NotImplementedError()
def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor] = None,
): ):
"""Init the metadata for a forward pass for capturing a cuda graph.""" """Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError() raise NotImplementedError()
@@ -30,6 +36,7 @@ class AttentionBackend(ABC):
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor] = None,
): ):
"""Init the metadata for a forward pass for replaying a cuda graph.""" """Init the metadata for a forward pass for replaying a cuda graph."""
raise NotImplementedError() raise NotImplementedError()
@@ -43,7 +50,7 @@ class AttentionBackend(ABC):
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
layer: nn.Module, layer: RadixAttention,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
): ):
"""Run forward on an attention layer.""" """Run forward on an attention layer."""
@@ -57,7 +64,7 @@ class AttentionBackend(ABC):
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
layer: nn.Module, layer: RadixAttention,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
): ):
"""Run a forward for decode.""" """Run a forward for decode."""
@@ -68,7 +75,7 @@ class AttentionBackend(ABC):
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
layer: nn.Module, layer: RadixAttention,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
): ):
"""Run a forward for extend.""" """Run a forward for extend."""

View File

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
) )
def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens=None,
): ):
# NOTE: encoder_lens expected to be zeros or None
self.forward_metadata = ( self.forward_metadata = (
self.cuda_graph_start_loc, self.cuda_graph_start_loc,
self.cuda_graph_attn_logits, self.cuda_graph_attn_logits,
@@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
encoder_lens=None,
): ):
# NOTE: encoder_lens expected to be zeros or None
self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
return 1 return 1
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# TODO: reuse the buffer across layers # TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim: if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
@@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
) )
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label layer, forward_batch.out_cache_loc, k, v, k_label
) )
( (
@@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
) )
return o return o
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# During torch.compile, there is a bug in rotary_emb that causes the # During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly. # output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
@@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
) )
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label layer, forward_batch.out_cache_loc, k, v, k_label
) )
# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num

View File

@@ -11,7 +11,6 @@ from enum import Enum, auto
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
import torch.nn as nn
import triton import triton
import triton.language as tl import triton.language as tl
@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
if is_flashinfer_available(): if is_flashinfer_available():
@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):
assert not ( assert not (
model_runner.sliding_window_size is not None model_runner.sliding_window_size is not None
and model_runner.has_cross_attention and model_runner.model_config.is_encoder_decoder
), "Sliding window and cross attention are not supported together" ), "Sliding window and cross attention are not supported together"
if model_runner.sliding_window_size is not None: if model_runner.sliding_window_size is not None:
self.num_wrappers = 2 self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
elif model_runner.has_cross_attention: elif model_runner.model_config.is_encoder_decoder:
self.num_wrappers = 2 self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
else: else:
@@ -128,6 +128,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.req_pool_indices, forward_batch.req_pool_indices,
forward_batch.seq_lens, forward_batch.seq_lens,
forward_batch.seq_lens_sum, forward_batch.seq_lens_sum,
decode_wrappers=None,
encoder_lens=forward_batch.encoder_lens,
) )
self.forward_metadata = (self.decode_wrappers,) self.forward_metadata = (self.decode_wrappers,)
else: else:
@@ -144,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.req_pool_indices, forward_batch.req_pool_indices,
forward_batch.seq_lens, forward_batch.seq_lens,
prefix_lens, prefix_lens,
use_ragged, use_ragged=use_ragged,
encoder_lens=forward_batch.encoder_lens,
) )
self.forward_metadata = ( self.forward_metadata = (use_ragged, extend_no_prefix)
use_ragged,
extend_no_prefix,
)
def init_cuda_graph_state(self, max_bs: int): def init_cuda_graph_state(self, max_bs: int):
cuda_graph_kv_indices = torch.zeros( cuda_graph_kv_indices = torch.zeros(
@@ -163,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
] ]
def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: torch.Tensor = None,
): ):
decode_wrappers = [] decode_wrappers = []
for i in range(self.num_wrappers): for i in range(self.num_wrappers):
@@ -181,7 +185,11 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum = seq_lens.sum().item() seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update( self.indices_updater_decode.update(
req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=decode_wrappers,
encoder_lens=encoder_lens,
) )
self.cuda_graph_metadata[bs] = decode_wrappers self.cuda_graph_metadata[bs] = decode_wrappers
self.forward_metadata = (decode_wrappers,) self.forward_metadata = (decode_wrappers,)
@@ -192,34 +200,42 @@ class FlashInferAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
encoder_lens: torch.Tensor = None,
): ):
self.indices_updater_decode.update( self.indices_updater_decode.update(
req_pool_indices[:bs], req_pool_indices[:bs],
seq_lens[:bs], seq_lens[:bs],
seq_lens_sum, seq_lens_sum,
self.cuda_graph_metadata[bs], decode_wrappers=self.cuda_graph_metadata[bs],
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
) )
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
return 0 return 0
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
prefill_wrapper_paged = self.prefill_wrappers_paged[ prefill_wrapper_paged = self.prefill_wrappers_paged[
self._get_wrapper_idx(layer) self._get_wrapper_idx(layer)
] ]
use_ragged, extend_no_prefix = self.forward_metadata use_ragged, extend_no_prefix = self.forward_metadata
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not use_ragged: if not use_ragged:
if k is not None: if k is not None:
assert v is not None assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
layer.layer_id, forward_batch.out_cache_loc, k, v
)
o = prefill_wrapper_paged.forward( o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=True, causal=not layer.is_cross_attention,
sm_scale=layer.scaling, sm_scale=layer.scaling,
window_left=layer.sliding_window_size, window_left=layer.sliding_window_size,
logits_soft_cap=layer.logit_cap, logits_soft_cap=layer.logit_cap,
@@ -247,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend):
o, _ = merge_state(o1, s1, o2, s2) o, _ = merge_state(o1, s1, o2, s2)
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
layer.layer_id, forward_batch.out_cache_loc, k, v
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim) return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)] decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if k is not None: if k is not None:
assert v is not None assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
layer.layer_id, forward_batch.out_cache_loc, k, v
)
o = decode_wrapper.forward( o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -271,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
return o.view(-1, layer.tp_q_head_num * layer.head_dim) return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def _get_wrapper_idx(self, layer: nn.Module): def _get_wrapper_idx(self, layer: RadixAttention):
if self.num_wrappers == 1: if self.num_wrappers == 1:
return 0 return 0
@@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1) self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers # Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -305,20 +326,27 @@ class FlashInferIndicesUpdaterDecode:
self.decode_wrappers = attn_backend.decode_wrappers self.decode_wrappers = attn_backend.decode_wrappers
# Dispatch # Dispatch
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window self.update = self.update_sliding_window
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
self.update = self.update_cross_attention self.update = self.update_cross_attention
else: else:
assert attn_backend.num_wrappers == 1 assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper self.update = self.update_single_wrapper
def update(
    self,
    req_pool_indices,
    seq_lens,
    seq_lens_sum,
    decode_wrappers=None,
    encoder_lens=None,
):
    """Placeholder for the decode indices update; never meant to run.

    ``__init__`` rebinds ``self.update`` to ``update_sliding_window``,
    ``update_cross_attention`` or ``update_single_wrapper`` according to
    the attention backend's ``dispatch_reason``, so this body is only
    reached if that dispatch did not happen.

    ``decode_wrappers`` and ``encoder_lens`` default to ``None`` so the
    declared signature matches the implementations it stands in for.

    Raises:
        NotImplementedError: always.
    """
    # Keep the signature for type checking, will be initialized during runtime
    raise NotImplementedError()
def update_single_wrapper( def update_single_wrapper(
self, self,
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
decode_wrappers=None, decode_wrappers=None,
encoder_lens=None,
): ):
decode_wrappers = decode_wrappers or self.decode_wrappers decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward( self.call_begin_forward(
@@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
decode_wrappers=None, decode_wrappers=None,
encoder_lens=None,
): ):
decode_wrappers = decode_wrappers or self.decode_wrappers decode_wrappers = decode_wrappers or self.decode_wrappers
@@ -363,8 +392,35 @@ class FlashInferIndicesUpdaterDecode:
kv_start_idx_tmp, kv_start_idx_tmp,
) )
def update_cross_attention(self): def update_cross_attention(
raise NotImplementedError() self,
req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=None,
encoder_lens=None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
for wrapper_id in range(2):
if wrapper_id == 0:
# Normal attention
paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens
else:
# Cross attention
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
seq_lens_sum = encoder_lens.sum().item()
self.call_begin_forward(
decode_wrappers[wrapper_id],
req_pool_indices,
paged_kernel_lens,
seq_lens_sum,
self.kv_indptr[wrapper_id],
kv_start_idx,
)
def call_begin_forward( def call_begin_forward(
self, self,
@@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1) self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers # Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -430,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
self.wrappers_paged = attn_backend.prefill_wrappers_paged self.wrappers_paged = attn_backend.prefill_wrappers_paged
# Dispatch # Dispatch
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window self.update = self.update_sliding_window
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
self.update = self.update_cross_attention self.update = self.update_cross_attention
else: else:
assert attn_backend.num_wrappers == 1 assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper self.update = self.update_single_wrapper
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
    """Placeholder for the prefill indices update; never meant to run.

    ``__init__`` rebinds ``self.update`` to ``update_sliding_window``,
    ``update_cross_attention`` or ``update_single_wrapper`` according to
    the attention backend's ``dispatch_reason``; all three take the same
    parameters as this stub.

    Raises:
        NotImplementedError: always.
    """
    # Keep the signature for type checking, will be initialized during runtime
    raise NotImplementedError()
def update_single_wrapper( def update_single_wrapper(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
): ):
if use_ragged: if use_ragged:
paged_kernel_lens = prefix_lens paged_kernel_lens = prefix_lens
@@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
) )
def update_sliding_window( def update_sliding_window(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
): ):
for wrapper_id in range(2): for wrapper_id in range(2):
if wrapper_id == 0: if wrapper_id == 0:
@@ -487,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill:
use_ragged, use_ragged,
) )
def update_cross_attention(self): def update_cross_attention(
raise NotImplementedError() self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
):
for wrapper_id in range(2):
if wrapper_id == 0:
# normal attention
paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens
else:
# cross attention
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
self.call_begin_forward(
self.wrapper_ragged,
self.wrappers_paged[wrapper_id],
req_pool_indices,
paged_kernel_lens,
seq_lens,
prefix_lens,
kv_start_idx,
self.kv_indptr[wrapper_id],
self.qo_indptr[wrapper_id],
use_ragged,
)
def call_begin_forward( def call_begin_forward(
self, self,

View File

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
) )
def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens=None,
): ):
# NOTE: encoder_lens expected to be zeros or None
self.forward_metadata = ( self.forward_metadata = (
self.cuda_graph_start_loc, self.cuda_graph_start_loc,
self.cuda_graph_attn_logits, self.cuda_graph_attn_logits,
@@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
encoder_lens=None,
): ):
# NOTE: encoder_lens expected to be zeros or None
self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
return 1 return 1
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# TODO: reuse the buffer across layers # TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim: if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
@@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
o = torch.empty_like(q) o = torch.empty_like(q)
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v layer, forward_batch.out_cache_loc, k, v
) )
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
@@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
) )
return o return o
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# During torch.compile, there is a bug in rotary_emb that causes the # During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly. # output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
@@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v layer, forward_batch.out_cache_loc, k, v
) )
self.decode_attention_fwd( self.decode_attention_fwd(

View File

@@ -33,20 +33,9 @@ def init_global_processor(server_args: ServerArgs):
class BaseImageProcessor(ABC): class BaseImageProcessor(ABC):
@abstractmethod def __init__(self, hf_config, server_args, _processor):
async def process_images_async(self, image_data, **kwargs):
pass
class DummyImageProcessor(BaseImageProcessor):
async def process_images_async(self, *args, **kwargs):
return None
class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config self.hf_config = hf_config
self._image_processor = _image_processor self._processor = _processor
self.executor = concurrent.futures.ProcessPoolExecutor( self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor, initializer=init_global_processor,
mp_context=mp.get_context("fork"), mp_context=mp.get_context("fork"),
@@ -54,6 +43,23 @@ class LlavaImageProcessor(BaseImageProcessor):
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()), max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
) )
@abstractmethod
async def process_images_async(self, image_data, input_text, **kwargs):
    """Process the image inputs of a request; subclasses must override.

    Args:
        image_data: the image input(s) to process.
        input_text: the prompt that accompanies the images.
        **kwargs: extra, subclass-specific arguments.
    """
    pass
class DummyImageProcessor(BaseImageProcessor):
    """No-op image processor: ``process_images_async`` always returns None."""

    def __init__(self):
        # Does not call BaseImageProcessor.__init__, so the attributes that
        # initializer sets (hf_config, _processor, executor) are not created
        # on this object.
        pass

    async def process_images_async(self, *args, **kwargs):
        """Accept and ignore any arguments; return None."""
        return None
class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod @staticmethod
def _process_single_image_task( def _process_single_image_task(
image_data: Union[str, bytes], image_data: Union[str, bytes],
@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor):
) )
async def process_images_async( async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj self, image_data: List[Union[str, bytes]], input_text, request_obj
): ):
if not image_data: if not image_data:
return None return None
@@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor):
} }
class MllamaImageProcessor(BaseImageProcessor):
    """Image processor for Mllama (Llama 3.2 vision) requests.

    Images and prompt text are passed together through the processor in a
    single call, and the ``input_ids`` it produces are returned as part of
    the result.
    """

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_single_image_task(images, input_text):
        """Run the processor on (images, text); used as the executor task.

        Calls the module-level ``global_processor`` rather than an instance
        attribute, since this static method is what gets submitted to the
        executor.
        """
        # Per the original note, the result is expected to carry the keys
        # 'input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids',
        # 'aspect_ratio_mask' and 'cross_attention_mask'.
        return global_processor(images, input_text, return_tensors="pt")

    async def _process_single_image(self, images, input_text):
        """Process one request, off-loading to the executor when one exists."""
        if self.executor is None:
            # No worker pool: run the processor inline.
            return self._processor(images, input_text, return_tensors="pt")

        # We are inside a coroutine, so a loop is guaranteed to be running.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor,
            MllamaImageProcessor._process_single_image_task,
            images,
            input_text,
        )

    async def process_images_async(
        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
    ):
        """Load the request's images and run them through the processor.

        Args:
            image_data: one image reference or a list of them; a single
                non-list value is wrapped into a one-element list.
            input_text: the prompt, either as a string or as a non-empty
                list of token ids (decoded with the processor's tokenizer).
            *args, **kwargs: accepted and ignored.

        Returns:
            None when ``image_data`` is empty. Otherwise the processor
            output with two entries set: ``image_hashes`` (a one-element
            list holding the hash of ``str(image_data)``) and ``input_ids``
            (the first row of the processor's ``input_ids`` as a plain list).
        """
        if not image_data:
            return None

        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

        if not isinstance(image_data, list):
            image_data = [image_data]

        # image_data is a non-empty list at this point, so the images are
        # always handed to the processor as a list (the former single-image
        # fallback branch could never execute).
        images = [load_image(image)[0] for image in image_data]

        image_inputs = await self._process_single_image(images, input_text)
        # NOTE(review): hash() of a str is salted per interpreter process
        # unless PYTHONHASHSEED is fixed -- confirm these hashes are never
        # compared across processes.
        image_inputs["image_hashes"] = [hash(str(image_data))]
        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]

        return image_inputs
class Qwen2VLImageProcessor(BaseImageProcessor): class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor): def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config self.hf_config = hf_config
@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
return self._process_single_image_task(image_data) return self._process_single_image_task(image_data)
async def process_images_async( async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj self, image_data: List[Union[str, bytes]], input_text, request_obj
): ):
if not image_data: if not image_data:
return None return None
@@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
def get_image_processor( def get_image_processor(
hf_config, server_args: ServerArgs, _image_processor hf_config, server_args: ServerArgs, processor
) -> BaseImageProcessor: ) -> BaseImageProcessor:
if "Qwen2VLForConditionalGeneration" in hf_config.architectures: if "MllamaForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, _image_processor) return MllamaImageProcessor(hf_config, server_args, processor)
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
else: else:
return LlavaImageProcessor(hf_config, server_args, _image_processor) return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
def get_dummy_image_processor(): def get_dummy_image_processor():

View File

@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained import RegexGuide from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -121,11 +122,12 @@ class ImageInputs:
"""The image related inputs.""" """The image related inputs."""
pixel_values: torch.Tensor pixel_values: torch.Tensor
image_hash: int image_hashes: Optional[list] = None
image_sizes: Optional[list] = None image_sizes: Optional[list] = None
image_offsets: Optional[list] = None image_offsets: Optional[list] = None
pad_values: Optional[list] = None pad_values: Optional[list] = None
modalities: Optional[list] = None modalities: Optional[list] = None
num_image_tokens: Optional[int] = None
image_embeds: Optional[List[torch.Tensor]] = None image_embeds: Optional[List[torch.Tensor]] = None
aspect_ratio_ids: Optional[List[torch.Tensor]] = None aspect_ratio_ids: Optional[List[torch.Tensor]] = None
@@ -138,19 +140,27 @@ class ImageInputs:
# Use image hash as fake token_ids, which is then used for prefix matching # Use image hash as fake token_ids, which is then used for prefix matching
ret = ImageInputs( ret = ImageInputs(
pixel_values=obj["pixel_values"], pixel_values=obj["pixel_values"],
image_hash=hash(tuple(obj["image_hashes"])), image_hashes=hash(tuple(obj["image_hashes"])),
image_grid_thws=obj.get("image_grid_thws"),
) )
image_hash = ret.image_hash image_hash = ret.image_hashes
ret.pad_values = [ ret.pad_values = [
(image_hash) % vocab_size, (image_hash) % vocab_size,
(image_hash >> 16) % vocab_size, (image_hash >> 16) % vocab_size,
(image_hash >> 32) % vocab_size, (image_hash >> 32) % vocab_size,
(image_hash >> 64) % vocab_size, (image_hash >> 64) % vocab_size,
] ]
ret.image_sizes = obj["image_sizes"]
# Only when pixel values is not None we have modalities optional_args = [
ret.modalities = obj["modalities"] or ["image"] "image_sizes",
"modalities",
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
]
for arg in optional_args:
if arg in obj:
setattr(ret, arg, obj[arg])
return ret return ret
@@ -416,6 +426,10 @@ class ScheduleBatch:
req_to_token_pool: ReqToTokenPool = None req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None token_to_kv_pool: BaseTokenToKVPool = None
tree_cache: BasePrefixCache = None tree_cache: BasePrefixCache = None
# For utility
model_config: ModelConfig = None
forward_mode: ForwardMode = None forward_mode: ForwardMode = None
sampling_info: SamplingBatchInfo = None sampling_info: SamplingBatchInfo = None
@@ -440,6 +454,12 @@ class ScheduleBatch:
extend_num_tokens: int = None extend_num_tokens: int = None
decoding_reqs: List[Req] = None decoding_reqs: List[Req] = None
# For encoder-decoder
encoder_cached: Optional[List[bool]] = None
encoder_lens: Optional[torch.Tensor] = None
encoder_lens_cpu: Optional[List[int]] = None
encoder_out_cache_loc: Optional[torch.Tensor] = None
# Stream # Stream
has_stream: bool = False has_stream: bool = False
@@ -450,12 +470,20 @@ class ScheduleBatch:
device: str = "cuda" device: str = "cuda"
@classmethod @classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): def init_new(
cls,
reqs,
req_to_token_pool,
token_to_kv_pool,
tree_cache,
model_config,
):
return cls( return cls(
reqs=reqs, reqs=reqs,
req_to_token_pool=req_to_token_pool, req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool, token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache, tree_cache=tree_cache,
model_config=model_config,
return_logprob=any(req.return_logprob for req in reqs), return_logprob=any(req.return_logprob for req in reqs),
has_stream=any(req.stream for req in reqs), has_stream=any(req.stream for req in reqs),
has_regex=any(req.regex_fsm for req in reqs), has_regex=any(req.regex_fsm for req in reqs),
@@ -493,7 +521,78 @@ class ScheduleBatch:
return out_cache_loc return out_cache_loc
def prepare_encoder_info_extend(
    self, input_ids: List[List[int]], seq_lens: List[int]
):
    """Split encoder (image) tokens from decoder tokens for an extend batch.

    For every request the leading ``num_image_tokens`` positions are treated
    as the encoder part. This method records the encoder lengths, strips the
    encoder part from the per-request inputs and lengths, and splits
    ``self.out_cache_loc`` into decoder and encoder cache locations.

    Args:
        input_ids: Per-request token id lists. Mutated in place: the encoder
            prefix is removed for requests whose encoder part is not cached.
        seq_lens: Per-request sequence lengths. Mutated in place: the encoder
            length is subtracted from each entry.

    Side effects:
        Sets ``encoder_lens_cpu``, ``encoder_cached``, ``encoder_lens``,
        ``encoder_out_cache_loc``, ``input_ids``, ``seq_lens`` and
        ``out_cache_loc`` on ``self``; adjusts ``extend_lens``,
        ``extend_num_tokens`` and ``prefix_lens`` in place.
    """
    self.encoder_lens_cpu = []
    self.encoder_cached = []

    for req in self.reqs:
        im = req.image_inputs
        if im is None or im.num_image_tokens is None:
            # No image input
            self.encoder_lens_cpu.append(0)
            self.encoder_cached.append(True)
        else:
            self.encoder_lens_cpu.append(im.num_image_tokens)
            # Cached when decoding, or when the matched prefix already
            # covers the whole encoder part.
            self.encoder_cached.append(
                self.forward_mode.is_decode()
                or len(req.prefix_indices) >= im.num_image_tokens
            )

    self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
        self.device, non_blocking=True
    )

    # Strip encoder infos
    pt = 0  # start of the current request's slice in self.out_cache_loc
    decoder_out_cache_loc = []
    encoder_out_cache_loc = []
    for i, req in enumerate(self.reqs):
        encoder_len = self.encoder_lens_cpu[i]
        seq_lens[i] -= encoder_len

        if len(req.prefix_indices) < encoder_len:
            # NOTE: the encoder part should be considered as a whole
            assert len(req.prefix_indices) == 0
            input_ids[i] = input_ids[i][encoder_len:]
            encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
            decoder_out_cache_loc.append(
                self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
            )
            self.extend_lens[i] -= encoder_len
            self.extend_num_tokens -= encoder_len
        else:
            # Encoder part is already in the prefix: everything being
            # extended belongs to the decoder.
            decoder_out_cache_loc.append(
                self.out_cache_loc[pt : pt + req.extend_input_len]
            )
            self.prefix_lens[i] -= encoder_len

        pt += req.extend_input_len

    # Reassign
    # Flatten in linear time (``sum(lists, [])`` re-copies the accumulator
    # for every request).
    flat_input_ids = [token for ids in input_ids for token in ids]
    self.input_ids = torch.tensor(flat_input_ids, dtype=torch.int32).to(
        self.device, non_blocking=True
    )
    self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
        self.device, non_blocking=True
    )

    if not decoder_out_cache_loc:
        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
    else:
        self.out_cache_loc = torch.cat(decoder_out_cache_loc)

    if not encoder_out_cache_loc:
        self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
    else:
        self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

    assert len(self.out_cache_loc) == self.extend_num_tokens
def prepare_for_extend(self):
self.forward_mode = ForwardMode.EXTEND self.forward_mode = ForwardMode.EXTEND
bs = len(self.reqs) bs = len(self.reqs)
@@ -561,8 +660,13 @@ class ScheduleBatch:
self.extend_lens = [r.extend_input_len for r in reqs] self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
if self.model_config.is_encoder_decoder:
self.prepare_encoder_info_extend(input_ids, seq_lens)
self.sampling_info = SamplingBatchInfo.from_schedule_batch( self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self, vocab_size, global_server_args_dict["disable_penalizer"] self,
self.model_config.vocab_size,
global_server_args_dict["disable_penalizer"],
) )
def mix_with_running(self, running_batch: "ScheduleBatch"): def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -752,6 +856,10 @@ class ScheduleBatch:
return jump_forward_reqs return jump_forward_reqs
def prepare_encoder_info_decode(self):
    """Mark the encoder part of every request in the batch as cached."""
    # Reset the encoder cached status
    num_reqs = len(self.reqs)
    self.encoder_cached = num_reqs * [True]
def prepare_for_decode(self, enable_overlap: bool = False): def prepare_for_decode(self, enable_overlap: bool = False):
self.forward_mode = ForwardMode.DECODE self.forward_mode = ForwardMode.DECODE
@@ -766,16 +874,22 @@ class ScheduleBatch:
bs = len(self.reqs) bs = len(self.reqs)
self.out_cache_loc = self.alloc_token_slots(bs) self.out_cache_loc = self.alloc_token_slots(bs)
if self.model_config.is_encoder_decoder:
locs = self.encoder_lens + self.seq_lens
self.prepare_encoder_info_decode()
else:
locs = self.seq_lens
if enable_overlap: if enable_overlap:
# Do not use in-place operations in the overlap mode # Do not use in-place operations in the overlap mode
self.req_to_token_pool.write( self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc (self.req_pool_indices, locs), self.out_cache_loc
) )
self.seq_lens = self.seq_lens + 1 self.seq_lens = self.seq_lens + 1
else: else:
# A faster in-place version # A faster in-place version
self.req_to_token_pool.write( self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc (self.req_pool_indices, locs), self.out_cache_loc
) )
self.seq_lens.add_(1) self.seq_lens.add_(1)
self.seq_lens_sum += bs self.seq_lens_sum += bs
@@ -802,6 +916,10 @@ class ScheduleBatch:
# No need to filter # No need to filter
return return
if self.model_config.is_encoder_decoder:
self.encoder_lens = self.encoder_lens[keep_indices]
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
self.reqs = [self.reqs[i] for i in keep_indices] self.reqs = [self.reqs[i] for i in keep_indices]
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to( new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
self.device, non_blocking=True self.device, non_blocking=True
@@ -828,6 +946,11 @@ class ScheduleBatch:
# needs to be called with pre-merged Batch.reqs. # needs to be called with pre-merged Batch.reqs.
self.sampling_info.merge_batch(other.sampling_info) self.sampling_info.merge_batch(other.sampling_info)
# Encoder-decoder infos
if self.model_config.is_encoder_decoder:
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
self.req_pool_indices = torch.concat( self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices] [self.req_pool_indices, other.req_pool_indices]
) )
@@ -850,14 +973,11 @@ class ScheduleBatch:
def get_model_worker_batch(self): def get_model_worker_batch(self):
if self.forward_mode.is_decode(): if self.forward_mode.is_decode():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = ( extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
image_inputs
) = None
else: else:
extend_seq_lens = self.extend_lens extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens extend_logprob_start_lens = self.extend_logprob_start_lens
image_inputs = [r.image_inputs for r in self.reqs]
if self.has_regex: if self.has_regex:
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs] self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
@@ -887,7 +1007,11 @@ class ScheduleBatch:
extend_seq_lens=extend_seq_lens, extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens, extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens, extend_logprob_start_lens=extend_logprob_start_lens,
image_inputs=image_inputs, image_inputs=[r.image_inputs for r in self.reqs],
encoder_cached=self.encoder_cached,
encoder_lens=self.encoder_lens,
encoder_lens_cpu=self.encoder_lens_cpu,
encoder_out_cache_loc=self.encoder_out_cache_loc,
lora_paths=[req.lora_path for req in self.reqs], lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info, sampling_info=self.sampling_info,
mrope_positions_delta=mrope_positions_delta, mrope_positions_delta=mrope_positions_delta,
@@ -897,6 +1021,7 @@ class ScheduleBatch:
# Only contain fields that will be used by process_batch_result # Only contain fields that will be used by process_batch_result
return ScheduleBatch( return ScheduleBatch(
reqs=self.reqs, reqs=self.reqs,
model_config=self.model_config,
forward_mode=self.forward_mode, forward_mode=self.forward_mode,
out_cache_loc=self.out_cache_loc, out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob, return_logprob=self.return_logprob,
@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
# For multimodal # For multimodal
image_inputs: Optional[List[ImageInputs]] image_inputs: Optional[List[ImageInputs]]
# For encoder-decoder
encoder_cached: Optional[List[bool]]
encoder_lens: Optional[torch.Tensor]
encoder_lens_cpu: Optional[List[int]]
encoder_out_cache_loc: Optional[torch.Tensor]
# For LoRA # For LoRA
lora_paths: Optional[List[str]] lora_paths: Optional[List[str]]

View File

@@ -662,8 +662,9 @@ class Scheduler:
self.req_to_token_pool, self.req_to_token_pool,
self.token_to_kv_pool, self.token_to_kv_pool,
self.tree_cache, self.tree_cache,
self.model_config,
) )
new_batch.prepare_for_extend(self.model_config.vocab_size) new_batch.prepare_for_extend()
# Mixed-style chunked prefill # Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None: if self.is_mixed_chunk and self.running_batch is not None:

View File

@@ -122,7 +122,7 @@ class TokenizerManager:
# We want to parallelize the image pre-processing so we create an executor for it # We want to parallelize the image pre-processing so we create an executor for it
self.image_processor = get_image_processor( self.image_processor = get_image_processor(
self.hf_config, server_args, self.processor.image_processor self.hf_config, server_args, self.processor
) )
else: else:
self.tokenizer = get_tokenizer( self.tokenizer = get_tokenizer(
@@ -191,8 +191,10 @@ class TokenizerManager:
sampling_params = self._get_sampling_params(obj.sampling_params) sampling_params = self._get_sampling_params(obj.sampling_params)
if self.is_generation: if self.is_generation:
image_inputs = await self.image_processor.process_images_async( image_inputs = await self.image_processor.process_images_async(
obj.image_data, obj obj.image_data, input_text or input_ids, obj
) )
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num top_logprobs_num = obj.top_logprobs_num
@@ -217,8 +219,10 @@ class TokenizerManager:
sampling_params = self._get_sampling_params(obj.sampling_params[index]) sampling_params = self._get_sampling_params(obj.sampling_params[index])
if self.is_generation: if self.is_generation:
image_inputs = await self.image_processor.process_images_async( image_inputs = await self.image_processor.process_images_async(
obj.image_data[index], obj obj.image_data[index], input_text or input_ids, obj
) )
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob[index] return_logprob = obj.return_logprob[index]
logprob_start_len = obj.logprob_start_len[index] logprob_start_len = obj.logprob_start_len[index]
top_logprobs_num = obj.top_logprobs_num[index] top_logprobs_num = obj.top_logprobs_num[index]
@@ -263,8 +267,10 @@ class TokenizerManager:
sampling_params = SamplingParams(**obj.sampling_params[0]) sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0 sampling_params.max_new_tokens = 0
image_inputs = await self.image_processor.process_images_async( image_inputs = await self.image_processor.process_images_async(
obj.image_data[0], obj obj.image_data[0], input_text or input_ids, obj
) )
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob[0] return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0] logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0] top_logprobs_num = obj.top_logprobs_num[0]

View File

@@ -26,6 +26,8 @@ from typing import List, Tuple, Union
import torch import torch
from sglang.srt.layers.radix_attention import RadixAttention
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -41,13 +43,17 @@ class ReqToTokenPool:
) )
self.free_slots = list(range(size)) self.free_slots = list(range(size))
self.write_records = [] self.write_records = []
self.use_records = use_records
if use_records: if self.use_records:
# records all write operations
self.write = self.write_with_records self.write = self.write_with_records
else: else:
self.write = self.write_without_records self.write = self.write_without_records
def write(self, indices, values):
    """Placeholder that only pins down the ``write`` signature.

    Exists for type checking; the constructor rebinds ``self.write`` to the
    recording or non-recording implementation at runtime.
    """
    raise NotImplementedError
def available_size(self): def available_size(self):
return len(self.free_slots) return len(self.free_slots)
@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
def set_kv_buffer( def set_kv_buffer(
self, self,
layer_id: int, layer: RadixAttention,
loc: torch.Tensor, loc: torch.Tensor,
cache_k: torch.Tensor, cache_k: torch.Tensor,
cache_v: torch.Tensor, cache_v: torch.Tensor,
@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer( def set_kv_buffer(
self, self,
layer_id: int, layer: RadixAttention,
loc: torch.Tensor, loc: torch.Tensor,
cache_k: torch.Tensor, cache_k: torch.Tensor,
cache_v: torch.Tensor, cache_v: torch.Tensor,
): ):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype: if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype) cache_k = cache_k.to(self.dtype)
if cache_v.dtype != self.dtype: if cache_v.dtype != self.dtype:
@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer( def set_kv_buffer(
self, self,
layer_id: int, layer: RadixAttention,
loc: torch.Tensor, loc: torch.Tensor,
cache_k: torch.Tensor, cache_k: torch.Tensor,
cache_v: torch.Tensor, cache_v: torch.Tensor,
): ):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype: if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype) cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype: if self.store_dtype != self.dtype:
@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer( def set_kv_buffer(
self, self,
layer_id: int, layer: RadixAttention,
loc: torch.Tensor, loc: torch.Tensor,
cache_k: torch.Tensor, cache_k: torch.Tensor,
cache_v: torch.Tensor, cache_v: torch.Tensor,
cache_label: torch.Tensor, cache_label: torch.Tensor,
): ):
# NOTE(Andy): ignore the dtype check # NOTE(Andy): ignore the dtype check
layer_id = layer.layer_id
self.k_buffer[layer_id][loc] = cache_k self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v self.v_buffer[layer_id][loc] = cache_v
self.label_buffer[layer_id][loc] = cache_label self.label_buffer[layer_id][loc] = cache_label

View File

@@ -105,6 +105,7 @@ class CudaGraphRunner:
self.graph_memory_pool = None self.graph_memory_pool = None
self.use_torch_compile = model_runner.server_args.enable_torch_compile self.use_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
# Batch sizes to capture # Batch sizes to capture
if self.model_runner.server_args.disable_cuda_graph_padding: if self.model_runner.server_args.disable_cuda_graph_padding:
@@ -132,6 +133,9 @@ class CudaGraphRunner:
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
) )
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
if self.use_torch_compile: if self.use_torch_compile:
set_torch_compile_config() set_torch_compile_config()
@@ -144,9 +148,18 @@ class CudaGraphRunner:
) )
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32) self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
if self.is_encoder_decoder:
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
self.encoder_lens = torch.full(
(self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
)
else:
self.encoder_lens = None
# Capture # Capture
try: try:
self.capture() with self.model_capture_mode():
self.capture()
except RuntimeError as e: except RuntimeError as e:
raise Exception( raise Exception(
f"Capture cuda graph failed: {e}\n" f"Capture cuda graph failed: {e}\n"
@@ -157,11 +170,32 @@ class CudaGraphRunner:
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
) )
@contextmanager
def model_capture_mode(self):
    """Set the model's ``capture_mode`` flag for the duration of the block.

    Models that do not define a ``capture_mode`` attribute are left
    untouched. The flag is reset in a ``finally`` clause so that a failed
    capture does not leave the model stuck in capture mode.
    """
    if hasattr(self.model_runner.model, "capture_mode"):
        self.model_runner.model.capture_mode = True

    try:
        yield
    finally:
        if hasattr(self.model_runner.model, "capture_mode"):
            self.model_runner.model.capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
    """Return whether ``forward_batch`` can be replayed with a captured graph.

    The batch size must be supported: an exact match in ``self.graphs`` when
    padding is disabled, otherwise at most ``self.max_bs``. For
    encoder-decoder models, every request must additionally have a positive
    encoder length.

    Returns:
        A plain ``bool``.
    """
    if self.disable_padding:
        is_bs_supported = forward_batch.batch_size in self.graphs
    else:
        is_bs_supported = forward_batch.batch_size <= self.max_bs

    # Check the batch size first so the tensor reduction below is skipped
    # when the batch cannot run anyway.
    if not is_bs_supported:
        return False

    if not self.is_encoder_decoder:
        return True

    # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
    # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
    # because the full_text_row_masked_out_mask tensor will always be ones
    return bool(torch.all(forward_batch.encoder_lens > 0))
def capture(self): def capture(self):
with graph_capture() as graph_capture_context: with graph_capture() as graph_capture_context:
@@ -188,11 +222,19 @@ class CudaGraphRunner:
req_pool_indices = self.req_pool_indices[:bs] req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs] seq_lens = self.seq_lens[:bs]
out_cache_loc = self.out_cache_loc[:bs] out_cache_loc = self.out_cache_loc[:bs]
if self.is_encoder_decoder:
encoder_lens = self.encoder_lens[:bs]
else:
encoder_lens = None
seq_lens_sum = seq_lens.sum().item() seq_lens_sum = seq_lens.sum().item()
# Attention backend # Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs, req_pool_indices, seq_lens bs,
req_pool_indices,
seq_lens,
encoder_lens,
) )
# Run and capture # Run and capture
@@ -208,6 +250,7 @@ class CudaGraphRunner:
attn_backend=self.model_runner.attn_backend, attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens_sum, seq_lens_sum=seq_lens_sum,
encoder_lens=encoder_lens,
return_logprob=False, return_logprob=False,
top_logprobs_nums=[0] * bs, top_logprobs_nums=[0] * bs,
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
@@ -251,6 +294,8 @@ class CudaGraphRunner:
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc) self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
# Attention backend # Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -258,6 +303,7 @@ class CudaGraphRunner:
self.req_pool_indices, self.req_pool_indices,
self.seq_lens, self.seq_lens,
forward_batch.seq_lens_sum, forward_batch.seq_lens_sum,
self.encoder_lens,
) )
# Replay # Replay

View File

@@ -108,6 +108,12 @@ class ForwardBatch:
# For multimodal # For multimodal
image_inputs: Optional[List[ImageInputs]] = None image_inputs: Optional[List[ImageInputs]] = None
# Encoder-decoder
encoder_cached: Optional[List[bool]] = None
encoder_lens: Optional[torch.Tensor] = None
encoder_lens_cpu: Optional[List[int]] = None
encoder_out_cache_loc: Optional[torch.Tensor] = None
# For LoRA # For LoRA
lora_paths: Optional[List[str]] = None lora_paths: Optional[List[str]] = None
@@ -194,6 +200,11 @@ class ForwardBatch:
req_pool_indices=batch.req_pool_indices, req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens, seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc, out_cache_loc=batch.out_cache_loc,
image_inputs=batch.image_inputs,
encoder_cached=batch.encoder_cached,
encoder_lens=batch.encoder_lens,
encoder_lens_cpu=batch.encoder_lens_cpu,
encoder_out_cache_loc=batch.encoder_out_cache_loc,
seq_lens_sum=batch.seq_lens_sum, seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob, return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=batch.top_logprobs_nums,
@@ -212,11 +223,11 @@ class ForwardBatch:
], ],
axis=0, axis=0,
) )
ret.image_inputs = batch.image_inputs
ret.extend_num_tokens = batch.extend_num_tokens ret.extend_num_tokens = batch.extend_num_tokens
ret.extend_seq_lens = torch.tensor( ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32 batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True) ).to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor( ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32 batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True) ).to(device, non_blocking=True)

View File

@@ -270,7 +270,6 @@ class ModelRunner:
if hasattr(self.model, "get_attention_sliding_window_size") if hasattr(self.model, "get_attention_sliding_window_size")
else None else None
) )
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
self.is_generation = is_generation_model( self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding self.model_config.hf_config.architectures, self.server_args.is_embedding
) )
@@ -510,7 +509,7 @@ class ModelRunner:
"Window attention is not supported in the triton attention backend. " "Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`." "Please use `--attention-backend flashinfer`."
) )
assert not self.has_cross_attention, ( assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. " "Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`." "Please use `--attention-backend flashinfer`."
) )
@@ -558,9 +557,7 @@ class ModelRunner:
self.cuda_graph_runner = CudaGraphRunner(self) self.cuda_graph_runner = CudaGraphRunner(self)
def forward_decode(self, forward_batch: ForwardBatch): def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run( if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
forward_batch.batch_size
):
return self.cuda_graph_runner.replay(forward_batch) return self.cuda_graph_runner.replay(forward_batch)
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64) forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)

File diff suppressed because it is too large Load Diff

View File

@@ -605,7 +605,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
] ]
positions = forward_batch.mrope_positions positions = forward_batch.mrope_positions
if image_inputs is None or len(image_inputs) == 0: if (
forward_batch.forward_mode.is_decode()
or image_inputs is None
or len(image_inputs) == 0
):
inputs_embeds = self.model.embed_tokens(input_ids) inputs_embeds = self.model.embed_tokens(input_ids)
else: else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":

View File

@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
or "LlavaQwenForCausalLM" in model_architectures or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures or "LlavaVidForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures or "Qwen2VLForConditionalGeneration" in model_architectures
): ):
return True return True

View File

@@ -171,7 +171,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
assert isinstance(text, str) assert isinstance(text, str)
print(text) print(text)
assert "man" in text or "cab" in text, text assert "man" in text or "cab" in text, text
assert "logo" in text, text assert "logo" in text or '"S"' in text or "SG" in text, text
assert response.id assert response.id
assert response.created assert response.created
assert response.usage.prompt_tokens > 0 assert response.usage.prompt_tokens > 0
@@ -363,5 +363,27 @@ class TestQWen2VLServer(TestOpenAIVisionServer):
cls.base_url += "/v1" cls.base_url += "/v1"
class TestMllamaServer(TestOpenAIVisionServer):
    """Run the shared OpenAI vision server tests against Llama 3.2 Vision."""

    @classmethod
    def setUpClass(cls):
        """Launch a server for the Mllama model with its chat template."""
        cls.model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--chat-template",
                "llama_3_vision",
            ],
        )
        # The OpenAI-compatible endpoints live under /v1.
        cls.base_url += "/v1"

    def test_video_chat_completion(self):
        # Report the inherited video test as skipped instead of letting an
        # empty override count as a pass.
        self.skipTest("Video chat completion is not tested for Mllama.")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()