Llama 3.2 vision model support (#1551)

This commit is contained in:
Liangsheng Yin
2024-10-21 15:01:21 -07:00
committed by GitHub
parent 00611286a1
commit 94cde10920
21 changed files with 1562 additions and 122 deletions

View File

@@ -8,16 +8,12 @@ version = "0.3.4"
description = "SGLang is yet another fast serving framework for large language models and vision language models."
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE"}
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"requests",
"tqdm",
"numpy",
]
dependencies = ["requests", "tqdm", "numpy"]
[project.optional-dependencies]
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
@@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
test = [
"jsonlines",
"matplotlib",
"pandas",
"sentence_transformers",
"accelerate",
"peft",
]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
dev = ["sglang[all]", "sglang[test]"]
@@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]
[tool.wheel]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]

View File

@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None,
model_config=model_runner.model_config,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)

View File

@@ -229,6 +229,7 @@ register_chat_template(
),
},
stop_str=("<|eot_id|>",),
image_token="<|image|>",
)
)

View File

@@ -89,6 +89,8 @@ class ModelConfig:
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""

View File

@@ -509,6 +509,19 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="llama_3_vision",
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA3,
sep="",
stop_str=["<|end_of_text|>", "<|eot_id|>"],
image_token="<|image|>",
)
)
register_conv_template(
Conversation(
name="llava_llama_3",

View File

@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from typing import Optional
import torch
from torch import nn
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -19,7 +21,11 @@ class AttentionBackend(ABC):
raise NotImplementedError()
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor] = None,
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
@@ -30,6 +36,7 @@ class AttentionBackend(ABC):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor] = None,
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
@@ -43,7 +50,7 @@ class AttentionBackend(ABC):
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: nn.Module,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
"""Run forward on an attention layer."""
@@ -57,7 +64,7 @@ class AttentionBackend(ABC):
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: nn.Module,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
"""Run a forward for decode."""
@@ -68,7 +75,7 @@ class AttentionBackend(ABC):
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: nn.Module,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
"""Run a forward for extend."""

View File

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
)
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens=None,
):
# NOTE: encoder_lens is expected to be all zeros or None here
self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
@@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens=None,
):
# NOTE: encoder_lens is expected to be all zeros or None here
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
@@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
layer, forward_batch.out_cache_loc, k, v, k_label
)
(
@@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
)
return o
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
@@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
layer, forward_batch.out_cache_loc, k, v, k_label
)
# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num

View File

@@ -11,7 +11,6 @@ from enum import Enum, auto
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
import triton
import triton.language as tl
@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
if is_flashinfer_available():
@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):
assert not (
model_runner.sliding_window_size is not None
and model_runner.has_cross_attention
and model_runner.model_config.is_encoder_decoder
), "Sliding window and cross attention are not supported together"
if model_runner.sliding_window_size is not None:
self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
elif model_runner.has_cross_attention:
elif model_runner.model_config.is_encoder_decoder:
self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
else:
@@ -128,6 +128,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
decode_wrappers=None,
encoder_lens=forward_batch.encoder_lens,
)
self.forward_metadata = (self.decode_wrappers,)
else:
@@ -144,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.req_pool_indices,
forward_batch.seq_lens,
prefix_lens,
use_ragged,
use_ragged=use_ragged,
encoder_lens=forward_batch.encoder_lens,
)
self.forward_metadata = (
use_ragged,
extend_no_prefix,
)
self.forward_metadata = (use_ragged, extend_no_prefix)
def init_cuda_graph_state(self, max_bs: int):
cuda_graph_kv_indices = torch.zeros(
@@ -163,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
]
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: torch.Tensor = None,
):
decode_wrappers = []
for i in range(self.num_wrappers):
@@ -181,7 +185,11 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update(
req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers
req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=decode_wrappers,
encoder_lens=encoder_lens,
)
self.cuda_graph_metadata[bs] = decode_wrappers
self.forward_metadata = (decode_wrappers,)
@@ -192,34 +200,42 @@ class FlashInferAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: torch.Tensor = None,
):
self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
self.cuda_graph_metadata[bs],
decode_wrappers=self.cuda_graph_metadata[bs],
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
)
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
prefill_wrapper_paged = self.prefill_wrappers_paged[
self._get_wrapper_idx(layer)
]
use_ragged, extend_no_prefix = self.forward_metadata
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not use_ragged:
if k is not None:
assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=True,
causal=not layer.is_cross_attention,
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=layer.logit_cap,
@@ -247,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend):
o, _ = merge_state(o1, s1, o2, s2)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if k is not None:
assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -271,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def _get_wrapper_idx(self, layer: nn.Module):
def _get_wrapper_idx(self, layer: RadixAttention):
if self.num_wrappers == 1:
return 0
@@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -305,20 +326,27 @@ class FlashInferIndicesUpdaterDecode:
self.decode_wrappers = attn_backend.decode_wrappers
# Dispatch
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
self.update = self.update_cross_attention
else:
assert attn_backend.num_wrappers == 1
assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper
def update(
self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
):
# Keep the signature for type checking; the real method is assigned in __init__
raise NotImplementedError()
def update_single_wrapper(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers=None,
encoder_lens=None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(
@@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers=None,
encoder_lens=None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
@@ -363,8 +392,35 @@ class FlashInferIndicesUpdaterDecode:
kv_start_idx_tmp,
)
def update_cross_attention(self):
raise NotImplementedError()
def update_cross_attention(
self,
req_pool_indices,
seq_lens,
seq_lens_sum,
decode_wrappers=None,
encoder_lens=None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
for wrapper_id in range(2):
if wrapper_id == 0:
# Normal attention
paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens
else:
# Cross attention
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
seq_lens_sum = encoder_lens.sum().item()
self.call_begin_forward(
decode_wrappers[wrapper_id],
req_pool_indices,
paged_kernel_lens,
seq_lens_sum,
self.kv_indptr[wrapper_id],
kv_start_idx,
)
def call_begin_forward(
self,
@@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -430,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
self.wrappers_paged = attn_backend.prefill_wrappers_paged
# Dispatch
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
self.update = self.update_cross_attention
else:
assert attn_backend.num_wrappers == 1
assert self.attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
# Keep the signature for type checking; the real method is assigned in __init__
raise NotImplementedError()
def update_single_wrapper(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
):
if use_ragged:
paged_kernel_lens = prefix_lens
@@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
)
def update_sliding_window(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
):
for wrapper_id in range(2):
if wrapper_id == 0:
@@ -487,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill:
use_ragged,
)
def update_cross_attention(self):
raise NotImplementedError()
def update_cross_attention(
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
):
for wrapper_id in range(2):
if wrapper_id == 0:
# Normal attention
paged_kernel_lens = seq_lens
kv_start_idx = encoder_lens
else:
# Cross attention
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
self.call_begin_forward(
self.wrapper_ragged,
self.wrappers_paged[wrapper_id],
req_pool_indices,
paged_kernel_lens,
seq_lens,
prefix_lens,
kv_start_idx,
self.kv_indptr[wrapper_id],
self.qo_indptr[wrapper_id],
use_ragged,
)
def call_begin_forward(
self,
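In both the decode and prefill updaters above, the cross-attention dispatch follows the same rule: wrapper 0 serves normal decoder self-attention and skips the first encoder_lens KV entries of each request, while wrapper 1 serves cross attention and reads exactly those encoder entries starting at offset 0. A small pure-Python illustration of that split for one request, with made-up lengths (the KV slots for a request are assumed to be laid out encoder tokens first, then decoder tokens):

# Toy per-request KV index split behind the two FlashInfer wrappers.
encoder_len = 4        # number of encoder (image) tokens, hypothetical
decoder_len = 6        # number of decoder (text) tokens, hypothetical

# Wrapper 0 (normal attention): kv_start_idx = encoder_len, length = decoder_len
self_attn_kv = list(range(encoder_len, encoder_len + decoder_len))

# Wrapper 1 (cross attention): kv_start_idx = 0, length = encoder_len
cross_attn_kv = list(range(0, encoder_len))

print(self_attn_kv)    # [4, 5, 6, 7, 8, 9]
print(cross_attn_kv)   # [0, 1, 2, 3]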

View File

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
)
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens=None,
):
# NOTE: encoder_lens is expected to be all zeros or None here
self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
@@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens=None,
):
# NOTE: encoder_lens is expected to be all zeros or None here
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
@@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
o = torch.empty_like(q)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
layer, forward_batch.out_cache_loc, k, v
)
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
@@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
)
return o
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
@@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
layer, forward_batch.out_cache_loc, k, v
)
self.decode_attention_fwd(

View File

@@ -33,20 +33,9 @@ def init_global_processor(server_args: ServerArgs):
class BaseImageProcessor(ABC):
@abstractmethod
async def process_images_async(self, image_data, **kwargs):
pass
class DummyImageProcessor(BaseImageProcessor):
async def process_images_async(self, *args, **kwargs):
return None
class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._image_processor = _image_processor
self._processor = _processor
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
@@ -54,6 +43,23 @@ class LlavaImageProcessor(BaseImageProcessor):
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
)
@abstractmethod
async def process_images_async(self, image_data, input_text, **kwargs):
pass
class DummyImageProcessor(BaseImageProcessor):
def __init__(self):
pass
async def process_images_async(self, *args, **kwargs):
return None
class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor):
)
async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj
self, image_data: List[Union[str, bytes]], input_text, request_obj
):
if not image_data:
return None
@@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor):
}
class MllamaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# The processor returns 'input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', and 'cross_attention_mask'
return global_processor(images, input_text, return_tensors="pt")
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MllamaImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(images, input_text, return_tensors="pt")
return image_inputs
async def process_images_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if not image_data:
return None
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
if not isinstance(image_data, list):
image_data = [image_data]
if len(image_data) > 0:
images = [load_image(image)[0] for image in image_data]
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["image_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs
class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
return self._process_single_image_task(image_data)
async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj
self, image_data: List[Union[str, bytes]], input_text, request_obj
):
if not image_data:
return None
@@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
def get_image_processor(
hf_config, server_args: ServerArgs, _image_processor
hf_config, server_args: ServerArgs, processor
) -> BaseImageProcessor:
if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
if "MllamaForConditionalGeneration" in hf_config.architectures:
return MllamaImageProcessor(hf_config, server_args, processor)
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
else:
return LlavaImageProcessor(hf_config, server_args, _image_processor)
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
def get_dummy_image_processor():
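get_image_processor now receives the full HF processor rather than only its image_processor attribute, and it dispatches on the architecture string: Mllama keeps the whole processor (it needs the tokenizer), while Qwen2-VL and Llava-family models still get processor.image_processor. A condensed restatement of that dispatch rule, with a stand-in processor object:

# Condensed restatement of the dispatch above; `proc` is a stand-in for the
# HF AutoProcessor, and only the attribute access pattern matters here.
from types import SimpleNamespace

def processor_argument_for(architectures, processor):
    if "MllamaForConditionalGeneration" in architectures:
        return processor                      # full processor, incl. tokenizer
    elif "Qwen2VLForConditionalGeneration" in architectures:
        return processor.image_processor
    else:
        return processor.image_processor      # Llava-family default

proc = SimpleNamespace(image_processor="image processor only")
print(processor_argument_for(["MllamaForConditionalGeneration"], proc))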

View File

@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
import torch
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -121,11 +122,12 @@ class ImageInputs:
"""The image related inputs."""
pixel_values: torch.Tensor
image_hash: int
image_hashes: Optional[list] = None
image_sizes: Optional[list] = None
image_offsets: Optional[list] = None
pad_values: Optional[list] = None
modalities: Optional[list] = None
num_image_tokens: Optional[int] = None
image_embeds: Optional[List[torch.Tensor]] = None
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
@@ -138,19 +140,27 @@ class ImageInputs:
# Use image hash as fake token_ids, which is then used for prefix matching
ret = ImageInputs(
pixel_values=obj["pixel_values"],
image_hash=hash(tuple(obj["image_hashes"])),
image_grid_thws=obj.get("image_grid_thws"),
image_hashes=hash(tuple(obj["image_hashes"])),
)
image_hash = ret.image_hash
image_hash = ret.image_hashes
ret.pad_values = [
(image_hash) % vocab_size,
(image_hash >> 16) % vocab_size,
(image_hash >> 32) % vocab_size,
(image_hash >> 64) % vocab_size,
]
ret.image_sizes = obj["image_sizes"]
# Modalities are set only when pixel_values is not None
ret.modalities = obj["modalities"] or ["image"]
optional_args = [
"image_sizes",
"modalities",
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
]
for arg in optional_args:
if arg in obj:
setattr(ret, arg, obj[arg])
return ret
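The pad_values derivation above folds the combined image hash into four pseudo token ids that are then used for prefix matching. A small self-contained example of the same arithmetic, with made-up hashes and a made-up vocab size:

# Reproduces the pad_values arithmetic with toy inputs.
image_hashes = [1234567890123456789, 987654321]   # hypothetical per-image hashes
vocab_size = 128256                               # hypothetical vocab size

combined = hash(tuple(image_hashes))
pad_values = [
    combined % vocab_size,
    (combined >> 16) % vocab_size,
    (combined >> 32) % vocab_size,
    (combined >> 64) % vocab_size,
]
print(pad_values)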
@@ -416,6 +426,10 @@ class ScheduleBatch:
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None
tree_cache: BasePrefixCache = None
# For utility
model_config: ModelConfig = None
forward_mode: ForwardMode = None
sampling_info: SamplingBatchInfo = None
@@ -440,6 +454,12 @@ class ScheduleBatch:
extend_num_tokens: int = None
decoding_reqs: List[Req] = None
# For encoder-decoder
encoder_cached: Optional[List[bool]] = None
encoder_lens: Optional[torch.Tensor] = None
encoder_lens_cpu: Optional[List[int]] = None
encoder_out_cache_loc: Optional[torch.Tensor] = None
# Stream
has_stream: bool = False
@@ -450,12 +470,20 @@ class ScheduleBatch:
device: str = "cuda"
@classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
def init_new(
cls,
reqs,
req_to_token_pool,
token_to_kv_pool,
tree_cache,
model_config,
):
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache,
model_config=model_config,
return_logprob=any(req.return_logprob for req in reqs),
has_stream=any(req.stream for req in reqs),
has_regex=any(req.regex_fsm for req in reqs),
@@ -493,7 +521,78 @@ class ScheduleBatch:
return out_cache_loc
def prepare_for_extend(self, vocab_size: int):
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []
self.encoder_cached = []
for req in self.reqs:
im = req.image_inputs
if im is None or im.num_image_tokens is None:
# No image input
self.encoder_lens_cpu.append(0)
self.encoder_cached.append(True)
else:
self.encoder_lens_cpu.append(im.num_image_tokens)
self.encoder_cached.append(
self.forward_mode.is_decode()
or len(req.prefix_indices) >= im.num_image_tokens
)
self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
self.device, non_blocking=True
)
# Strip encoder infos
pt = 0
decoder_out_cache_loc = []
encoder_out_cache_loc = []
for i, req in enumerate(self.reqs):
encoder_len = self.encoder_lens_cpu[i]
seq_lens[i] -= encoder_len
if len(req.prefix_indices) < encoder_len:
# NOTE: the encoder part should be considered as a whole
assert len(req.prefix_indices) == 0
input_ids[i] = input_ids[i][encoder_len:]
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
decoder_out_cache_loc.append(
self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
)
self.extend_lens[i] -= encoder_len
self.extend_num_tokens -= encoder_len
else:
decoder_out_cache_loc.append(
self.out_cache_loc[pt : pt + req.extend_input_len]
)
self.prefix_lens[i] -= encoder_len
pt += req.extend_input_len
# Reassign
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
self.device, non_blocking=True
)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
if not decoder_out_cache_loc:
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
else:
self.out_cache_loc = torch.cat(decoder_out_cache_loc)
if not encoder_out_cache_loc:
self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
else:
self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
assert len(self.out_cache_loc) == self.extend_num_tokens
def prepare_for_extend(self):
self.forward_mode = ForwardMode.EXTEND
bs = len(self.reqs)
@@ -561,8 +660,13 @@ class ScheduleBatch:
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
if self.model_config.is_encoder_decoder:
self.prepare_encoder_info_extend(input_ids, seq_lens)
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self, vocab_size, global_server_args_dict["disable_penalizer"]
self,
self.model_config.vocab_size,
global_server_args_dict["disable_penalizer"],
)
def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -752,6 +856,10 @@ class ScheduleBatch:
return jump_forward_reqs
def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)
def prepare_for_decode(self, enable_overlap: bool = False):
self.forward_mode = ForwardMode.DECODE
@@ -766,16 +874,22 @@ class ScheduleBatch:
bs = len(self.reqs)
self.out_cache_loc = self.alloc_token_slots(bs)
if self.model_config.is_encoder_decoder:
locs = self.encoder_lens + self.seq_lens
self.prepare_encoder_info_decode()
else:
locs = self.seq_lens
if enable_overlap:
# Do not use in-place operations in the overlap mode
self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
(self.req_pool_indices, locs), self.out_cache_loc
)
self.seq_lens = self.seq_lens + 1
else:
# A faster in-place version
self.req_to_token_pool.write(
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
(self.req_pool_indices, locs), self.out_cache_loc
)
self.seq_lens.add_(1)
self.seq_lens_sum += bs
@@ -802,6 +916,10 @@ class ScheduleBatch:
# No need to filter
return
if self.model_config.is_encoder_decoder:
self.encoder_lens = self.encoder_lens[keep_indices]
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
self.reqs = [self.reqs[i] for i in keep_indices]
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
self.device, non_blocking=True
@@ -828,6 +946,11 @@ class ScheduleBatch:
# needs to be called with pre-merged Batch.reqs.
self.sampling_info.merge_batch(other.sampling_info)
# Encoder-decoder infos
if self.model_config.is_encoder_decoder:
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices]
)
@@ -850,14 +973,11 @@ class ScheduleBatch:
def get_model_worker_batch(self):
if self.forward_mode.is_decode():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
image_inputs
) = None
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
image_inputs = [r.image_inputs for r in self.reqs]
if self.has_regex:
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
@@ -887,7 +1007,11 @@ class ScheduleBatch:
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
image_inputs=image_inputs,
image_inputs=[r.image_inputs for r in self.reqs],
encoder_cached=self.encoder_cached,
encoder_lens=self.encoder_lens,
encoder_lens_cpu=self.encoder_lens_cpu,
encoder_out_cache_loc=self.encoder_out_cache_loc,
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
mrope_positions_delta=mrope_positions_delta,
@@ -897,6 +1021,7 @@ class ScheduleBatch:
# Only contain fields that will be used by process_batch_result
return ScheduleBatch(
reqs=self.reqs,
model_config=self.model_config,
forward_mode=self.forward_mode,
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
# For multimodal
image_inputs: Optional[List[ImageInputs]]
# For encoder-decoder
encoder_cached: Optional[List[bool]]
encoder_lens: Optional[torch.Tensor]
encoder_lens_cpu: Optional[List[int]]
encoder_out_cache_loc: Optional[torch.Tensor]
# For LoRA
lora_paths: Optional[List[str]]
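To make the bookkeeping in prepare_encoder_info_extend concrete, here is a toy walkthrough with hypothetical numbers: a single request whose prompt is 4 image (encoder) tokens followed by 6 text tokens, with nothing cached yet. The names mirror the fields above but the values are illustrative only.

# Toy walkthrough of the encoder/decoder split for one uncached request.
encoder_len = 4                        # im.num_image_tokens (hypothetical)
extend_input_len = 10                  # 4 encoder tokens + 6 decoder tokens
out_cache_loc = list(range(100, 110))  # 10 freshly allocated KV slots

# The first encoder_len slots hold the encoder KV, the rest the decoder KV.
encoder_out_cache_loc = out_cache_loc[:encoder_len]                   # [100..103]
decoder_out_cache_loc = out_cache_loc[encoder_len:extend_input_len]   # [104..109]

# seq_lens / extend_lens / extend_num_tokens are all reduced by encoder_len,
# so downstream decoder-side code only ever sees the 6 text tokens.
seq_len_after_strip = extend_input_len - encoder_len                  # 6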

View File

@@ -662,8 +662,9 @@ class Scheduler:
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
self.model_config,
)
new_batch.prepare_for_extend(self.model_config.vocab_size)
new_batch.prepare_for_extend()
# Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None:

View File

@@ -122,7 +122,7 @@ class TokenizerManager:
# We want to parallelize the image pre-processing so we create an executor for it
self.image_processor = get_image_processor(
self.hf_config, server_args, self.processor.image_processor
self.hf_config, server_args, self.processor
)
else:
self.tokenizer = get_tokenizer(
@@ -191,8 +191,10 @@ class TokenizerManager:
sampling_params = self._get_sampling_params(obj.sampling_params)
if self.is_generation:
image_inputs = await self.image_processor.process_images_async(
obj.image_data, obj
obj.image_data, input_text or input_ids, obj
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
@@ -217,8 +219,10 @@ class TokenizerManager:
sampling_params = self._get_sampling_params(obj.sampling_params[index])
if self.is_generation:
image_inputs = await self.image_processor.process_images_async(
obj.image_data[index], obj
obj.image_data[index], input_text or input_ids, obj
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob[index]
logprob_start_len = obj.logprob_start_len[index]
top_logprobs_num = obj.top_logprobs_num[index]
@@ -263,8 +267,10 @@ class TokenizerManager:
sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0
image_inputs = await self.image_processor.process_images_async(
obj.image_data[0], obj
obj.image_data[0], input_text or input_ids, obj
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0]
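The pattern repeated at each call site above is the same: the image processor is now given the request text (or token ids), and if it returns its own input_ids (as the Mllama processor does), those replace the ids the tokenizer produced. A condensed sketch of that contract, with placeholder names for the request fields:

# Condensed sketch of the new call contract; `obj`, `input_text`, and
# `input_ids` stand in for the request fields used in TokenizerManager.
async def tokenize_with_images(image_processor, obj, input_text, input_ids):
    image_inputs = await image_processor.process_images_async(
        obj.image_data, input_text or input_ids, obj
    )
    if image_inputs and "input_ids" in image_inputs:
        # The multimodal processor re-tokenized the prompt (e.g. expanded
        # image placeholders), so its ids take precedence.
        input_ids = image_inputs["input_ids"]
    return input_ids, image_inputs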

View File

@@ -26,6 +26,8 @@ from typing import List, Tuple, Union
import torch
from sglang.srt.layers.radix_attention import RadixAttention
logger = logging.getLogger(__name__)
@@ -41,13 +43,17 @@ class ReqToTokenPool:
)
self.free_slots = list(range(size))
self.write_records = []
self.use_records = use_records
if use_records:
# records all write operations
if self.use_records:
self.write = self.write_with_records
else:
self.write = self.write_without_records
def write(self, indices, values):
# Keep the signature for type checking; the real method is assigned in __init__
raise NotImplementedError()
def available_size(self):
return len(self.free_slots)
@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
def set_kv_buffer(
self,
layer_id: int,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer(
self,
layer_id: int,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if cache_v.dtype != self.dtype:
@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer(
self,
layer_id: int,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype:
@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
def set_kv_buffer(
self,
layer_id: int,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
cache_label: torch.Tensor,
):
# NOTE(Andy): ignore the dtype check
layer_id = layer.layer_id
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
self.label_buffer[layer_id][loc] = cache_label
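After this change the KV pools take the RadixAttention layer itself rather than a bare layer_id, letting a pool read any per-layer attribute it needs (today just layer.layer_id). A minimal stand-alone illustration of the pattern follows; TinyKVPool and FakeLayer are toys, not the real pool or layer classes.

# Minimal illustration of passing the layer object instead of an integer id.
import torch

class TinyKVPool:
    def __init__(self, num_layers, size, head_num, head_dim):
        self.k_buffer = [torch.zeros(size, head_num, head_dim) for _ in range(num_layers)]
        self.v_buffer = [torch.zeros(size, head_num, head_dim) for _ in range(num_layers)]

    def set_kv_buffer(self, layer, loc, cache_k, cache_v):
        layer_id = layer.layer_id           # derived from the layer object
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v

class FakeLayer:
    def __init__(self, layer_id):
        self.layer_id = layer_id

pool = TinyKVPool(num_layers=2, size=16, head_num=1, head_dim=4)
loc = torch.tensor([0, 1])
pool.set_kv_buffer(FakeLayer(1), loc, torch.ones(2, 1, 4), torch.ones(2, 1, 4))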

View File

@@ -105,6 +105,7 @@ class CudaGraphRunner:
self.graph_memory_pool = None
self.use_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
# Batch sizes to capture
if self.model_runner.server_args.disable_cuda_graph_padding:
@@ -132,6 +133,9 @@ class CudaGraphRunner:
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
if self.use_torch_compile:
set_torch_compile_config()
@@ -144,9 +148,18 @@ class CudaGraphRunner:
)
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
if self.is_encoder_decoder:
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
self.encoder_lens = torch.full(
(self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
)
else:
self.encoder_lens = None
# Capture
try:
self.capture()
with self.model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
@@ -157,11 +170,32 @@ class CudaGraphRunner:
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
def can_run(self, batch_size: int):
if self.disable_padding:
return batch_size in self.graphs
else:
return batch_size <= self.max_bs
@contextmanager
def model_capture_mode(self):
if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = True
yield
if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
# because the full_text_row_masked_out_mask tensor will always be ones
is_encoder_lens_supported = (
torch.all(forward_batch.encoder_lens > 0)
if self.is_encoder_decoder
else True
)
return is_bs_supported and is_encoder_lens_supported
def capture(self):
with graph_capture() as graph_capture_context:
@@ -188,11 +222,19 @@ class CudaGraphRunner:
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
out_cache_loc = self.out_cache_loc[:bs]
if self.is_encoder_decoder:
encoder_lens = self.encoder_lens[:bs]
else:
encoder_lens = None
seq_lens_sum = seq_lens.sum().item()
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs, req_pool_indices, seq_lens
bs,
req_pool_indices,
seq_lens,
encoder_lens,
)
# Run and capture
@@ -208,6 +250,7 @@ class CudaGraphRunner:
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens_sum,
encoder_lens=encoder_lens,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
@@ -251,6 +294,8 @@ class CudaGraphRunner:
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -258,6 +303,7 @@ class CudaGraphRunner:
self.req_pool_indices,
self.seq_lens,
forward_batch.seq_lens_sum,
self.encoder_lens,
)
# Replay
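model_capture_mode gives a model a hook to behave differently while the CUDA graph is being captured: any model that defines a capture_mode attribute sees it flipped to True for the duration of capture and back to False afterwards. A toy illustration of the pattern is below; the context manager mirrors the one above, while ToyModel is a stand-in rather than a real SGLang model.

# Toy illustration of the capture_mode toggle; not the real ModelRunner/model.
from contextlib import contextmanager

class ToyModel:
    capture_mode = False

    def forward(self):
        if self.capture_mode:
            return "graph-capture-friendly path"
        return "normal path"

@contextmanager
def model_capture_mode(model):
    if hasattr(model, "capture_mode"):
        model.capture_mode = True
    yield
    if hasattr(model, "capture_mode"):
        model.capture_mode = False

m = ToyModel()
with model_capture_mode(m):
    print(m.forward())   # graph-capture-friendly path
print(m.forward())       # normal path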

View File

@@ -108,6 +108,12 @@ class ForwardBatch:
# For multimodal
image_inputs: Optional[List[ImageInputs]] = None
# Encoder-decoder
encoder_cached: Optional[List[bool]] = None
encoder_lens: Optional[torch.Tensor] = None
encoder_lens_cpu: Optional[List[int]] = None
encoder_out_cache_loc: Optional[torch.Tensor] = None
# For LoRA
lora_paths: Optional[List[str]] = None
@@ -194,6 +200,11 @@ class ForwardBatch:
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc,
image_inputs=batch.image_inputs,
encoder_cached=batch.encoder_cached,
encoder_lens=batch.encoder_lens,
encoder_lens_cpu=batch.encoder_lens_cpu,
encoder_out_cache_loc=batch.encoder_out_cache_loc,
seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
@@ -212,11 +223,11 @@ class ForwardBatch:
],
axis=0,
)
ret.image_inputs = batch.image_inputs
ret.extend_num_tokens = batch.extend_num_tokens
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True)

View File

@@ -270,7 +270,6 @@ class ModelRunner:
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
@@ -510,7 +509,7 @@ class ModelRunner:
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
assert not self.has_cross_attention, (
assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
@@ -558,9 +557,7 @@ class ModelRunner:
self.cuda_graph_runner = CudaGraphRunner(self)
def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
forward_batch.batch_size
):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
return self.cuda_graph_runner.replay(forward_batch)
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)

File diff suppressed because it is too large

View File

@@ -605,7 +605,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
]
positions = forward_batch.mrope_positions
if image_inputs is None or len(image_inputs) == 0:
if (
forward_batch.forward_mode.is_decode()
or image_inputs is None
or len(image_inputs) == 0
):
inputs_embeds = self.model.embed_tokens(input_ids)
else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":

View File

@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures
):
return True

View File

@@ -171,7 +171,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
assert isinstance(text, str)
print(text)
assert "man" in text or "cab" in text, text
assert "logo" in text, text
assert "logo" in text or '"S"' in text or "SG" in text, text
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
@@ -363,5 +363,27 @@ class TestQWen2VLServer(TestOpenAIVisionServer):
cls.base_url += "/v1"
class TestMllamaServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
cls.model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"llama_3_vision",
],
)
cls.base_url += "/v1"
def test_video_chat_completion(self):
pass
if __name__ == "__main__":
unittest.main()
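For manual testing outside the unit test, the same setup can be exercised through the OpenAI-compatible endpoint. The sketch below assumes a locally launched server (the port and image URL are placeholders) and the standard openai client; it is an illustrative request, not code taken from the test file.

# Hypothetical manual check against a locally running server started with:
#   python -m sglang.launch_server --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \
#       --chat-template llama_3_vision --port 30000
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
    max_tokens=64,
)
print(response.choices[0].message.content)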