Llama3.2 vision model support (#1551)
This commit is contained in:
@@ -8,16 +8,12 @@ version = "0.3.4"
|
|||||||
description = "SGLang is yet another fast serving framework for large language models and vision language models."
|
description = "SGLang is yet another fast serving framework for large language models and vision language models."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.8"
|
requires-python = ">=3.8"
|
||||||
license = {file = "LICENSE"}
|
license = { file = "LICENSE" }
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"License :: OSI Approved :: Apache Software License",
|
"License :: OSI Approved :: Apache Software License",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = ["requests", "tqdm", "numpy"]
|
||||||
"requests",
|
|
||||||
"tqdm",
|
|
||||||
"numpy",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
|
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
|
||||||
@@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"]
|
|||||||
openai = ["openai>=1.0", "tiktoken"]
|
openai = ["openai>=1.0", "tiktoken"]
|
||||||
anthropic = ["anthropic>=0.20.0"]
|
anthropic = ["anthropic>=0.20.0"]
|
||||||
litellm = ["litellm>=1.0.0"]
|
litellm = ["litellm>=1.0.0"]
|
||||||
test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
|
test = [
|
||||||
|
"jsonlines",
|
||||||
|
"matplotlib",
|
||||||
|
"pandas",
|
||||||
|
"sentence_transformers",
|
||||||
|
"accelerate",
|
||||||
|
"peft",
|
||||||
|
]
|
||||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||||
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||||
dev = ["sglang[all]", "sglang[test]"]
|
dev = ["sglang[all]", "sglang[test]"]
|
||||||
@@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
|
|||||||
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
|
exclude = [
|
||||||
|
"assets*",
|
||||||
|
"benchmark*",
|
||||||
|
"docs*",
|
||||||
|
"dist*",
|
||||||
|
"playground*",
|
||||||
|
"scripts*",
|
||||||
|
"tests*",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.wheel]
|
[tool.wheel]
|
||||||
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
|
exclude = [
|
||||||
|
"assets*",
|
||||||
|
"benchmark*",
|
||||||
|
"docs*",
|
||||||
|
"dist*",
|
||||||
|
"playground*",
|
||||||
|
"scripts*",
|
||||||
|
"tests*",
|
||||||
|
]
|
||||||
|
|||||||
@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
|
|||||||
req_to_token_pool=model_runner.req_to_token_pool,
|
req_to_token_pool=model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||||
tree_cache=None,
|
tree_cache=None,
|
||||||
|
model_config=model_runner.model_config,
|
||||||
)
|
)
|
||||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
batch.prepare_for_extend()
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||||
logits_output = model_runner.forward(forward_batch)
|
logits_output = model_runner.forward(forward_batch)
|
||||||
|
|||||||
@@ -229,6 +229,7 @@ register_chat_template(
|
|||||||
),
|
),
|
||||||
},
|
},
|
||||||
stop_str=("<|eot_id|>",),
|
stop_str=("<|eot_id|>",),
|
||||||
|
image_token="<|image|>",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -89,6 +89,8 @@ class ModelConfig:
|
|||||||
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
|
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
|
||||||
self.vocab_size = self.hf_text_config.vocab_size
|
self.vocab_size = self.hf_text_config.vocab_size
|
||||||
|
|
||||||
|
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
|
||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
||||||
def get_total_num_kv_heads(self) -> int:
|
def get_total_num_kv_heads(self) -> int:
|
||||||
"""Returns the total number of KV heads."""
|
"""Returns the total number of KV heads."""
|
||||||
|
|||||||
@@ -509,6 +509,19 @@ register_conv_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="llama_3_vision",
|
||||||
|
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
|
||||||
|
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
||||||
|
roles=("user", "assistant"),
|
||||||
|
sep_style=SeparatorStyle.LLAMA3,
|
||||||
|
sep="",
|
||||||
|
stop_str=["<|end_of_text|>", "<|eot_id|>"],
|
||||||
|
image_token="<|image|>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
register_conv_template(
|
register_conv_template(
|
||||||
Conversation(
|
Conversation(
|
||||||
name="llava_llama_3",
|
name="llava_llama_3",
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
|
|
||||||
@@ -19,7 +21,11 @@ class AttentionBackend(ABC):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
encoder_lens: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -30,6 +36,7 @@ class AttentionBackend(ABC):
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
|
encoder_lens: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
"""Init the metadata for a forward pass for replying a cuda graph."""
|
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -43,7 +50,7 @@ class AttentionBackend(ABC):
|
|||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
layer: nn.Module,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
):
|
):
|
||||||
"""Run forward on an attention layer."""
|
"""Run forward on an attention layer."""
|
||||||
@@ -57,7 +64,7 @@ class AttentionBackend(ABC):
|
|||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
layer: nn.Module,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
):
|
):
|
||||||
"""Run a forward for decode."""
|
"""Run a forward for decode."""
|
||||||
@@ -68,7 +75,7 @@ class AttentionBackend(ABC):
|
|||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
layer: nn.Module,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
):
|
):
|
||||||
"""Run a forward for extend."""
|
"""Run a forward for extend."""
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
encoder_lens=None,
|
||||||
):
|
):
|
||||||
|
# NOTE: encoder_lens expected to be zeros or None
|
||||||
self.forward_metadata = (
|
self.forward_metadata = (
|
||||||
self.cuda_graph_start_loc,
|
self.cuda_graph_start_loc,
|
||||||
self.cuda_graph_attn_logits,
|
self.cuda_graph_attn_logits,
|
||||||
@@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
|
encoder_lens=None,
|
||||||
):
|
):
|
||||||
|
# NOTE: encoder_lens expected to be zeros or None
|
||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_extend(
|
||||||
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
|
):
|
||||||
# TODO: reuse the buffer across layers
|
# TODO: reuse the buffer across layers
|
||||||
if layer.qk_head_dim != layer.v_head_dim:
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||||
@@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
|
layer, forward_batch.out_cache_loc, k, v, k_label
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
@@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_decode(
|
||||||
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
|
):
|
||||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||||
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||||
@@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
|
layer, forward_batch.out_cache_loc, k, v, k_label
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
|
# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from enum import Enum, auto
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
model_runner.sliding_window_size is not None
|
model_runner.sliding_window_size is not None
|
||||||
and model_runner.has_cross_attention
|
and model_runner.model_config.is_encoder_decoder
|
||||||
), "Sliding window and cross attention are not supported together"
|
), "Sliding window and cross attention are not supported together"
|
||||||
|
|
||||||
if model_runner.sliding_window_size is not None:
|
if model_runner.sliding_window_size is not None:
|
||||||
self.num_wrappers = 2
|
self.num_wrappers = 2
|
||||||
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
|
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
|
||||||
elif model_runner.has_cross_attention:
|
elif model_runner.model_config.is_encoder_decoder:
|
||||||
self.num_wrappers = 2
|
self.num_wrappers = 2
|
||||||
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
|
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
|
||||||
else:
|
else:
|
||||||
@@ -128,6 +128,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.seq_lens_sum,
|
||||||
|
decode_wrappers=None,
|
||||||
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
)
|
)
|
||||||
self.forward_metadata = (self.decode_wrappers,)
|
self.forward_metadata = (self.decode_wrappers,)
|
||||||
else:
|
else:
|
||||||
@@ -144,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
use_ragged,
|
use_ragged=use_ragged,
|
||||||
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.forward_metadata = (
|
self.forward_metadata = (use_ragged, extend_no_prefix)
|
||||||
use_ragged,
|
|
||||||
extend_no_prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
cuda_graph_kv_indices = torch.zeros(
|
cuda_graph_kv_indices = torch.zeros(
|
||||||
@@ -163,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
encoder_lens: torch.Tensor = None,
|
||||||
):
|
):
|
||||||
decode_wrappers = []
|
decode_wrappers = []
|
||||||
for i in range(self.num_wrappers):
|
for i in range(self.num_wrappers):
|
||||||
@@ -181,7 +185,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
seq_lens_sum = seq_lens.sum().item()
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
decode_wrappers=decode_wrappers,
|
||||||
|
encoder_lens=encoder_lens,
|
||||||
)
|
)
|
||||||
self.cuda_graph_metadata[bs] = decode_wrappers
|
self.cuda_graph_metadata[bs] = decode_wrappers
|
||||||
self.forward_metadata = (decode_wrappers,)
|
self.forward_metadata = (decode_wrappers,)
|
||||||
@@ -192,34 +200,42 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
|
encoder_lens: torch.Tensor = None,
|
||||||
):
|
):
|
||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
req_pool_indices[:bs],
|
req_pool_indices[:bs],
|
||||||
seq_lens[:bs],
|
seq_lens[:bs],
|
||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
self.cuda_graph_metadata[bs],
|
decode_wrappers=self.cuda_graph_metadata[bs],
|
||||||
|
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_extend(
|
||||||
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
|
):
|
||||||
prefill_wrapper_paged = self.prefill_wrappers_paged[
|
prefill_wrapper_paged = self.prefill_wrappers_paged[
|
||||||
self._get_wrapper_idx(layer)
|
self._get_wrapper_idx(layer)
|
||||||
]
|
]
|
||||||
|
|
||||||
use_ragged, extend_no_prefix = self.forward_metadata
|
use_ragged, extend_no_prefix = self.forward_metadata
|
||||||
|
cache_loc = (
|
||||||
|
forward_batch.out_cache_loc
|
||||||
|
if not layer.is_cross_attention
|
||||||
|
else forward_batch.encoder_out_cache_loc
|
||||||
|
)
|
||||||
|
|
||||||
if not use_ragged:
|
if not use_ragged:
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
|
||||||
)
|
|
||||||
o = prefill_wrapper_paged.forward(
|
o = prefill_wrapper_paged.forward(
|
||||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||||
causal=True,
|
causal=not layer.is_cross_attention,
|
||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
window_left=layer.sliding_window_size,
|
window_left=layer.sliding_window_size,
|
||||||
logits_soft_cap=layer.logit_cap,
|
logits_soft_cap=layer.logit_cap,
|
||||||
@@ -247,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
o, _ = merge_state(o1, s1, o2, s2)
|
o, _ = merge_state(o1, s1, o2, s2)
|
||||||
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
|
||||||
)
|
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_decode(
|
||||||
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
|
):
|
||||||
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
|
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
|
||||||
|
cache_loc = (
|
||||||
|
forward_batch.out_cache_loc
|
||||||
|
if not layer.is_cross_attention
|
||||||
|
else forward_batch.encoder_out_cache_loc
|
||||||
|
)
|
||||||
|
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
|
||||||
)
|
|
||||||
|
|
||||||
o = decode_wrapper.forward(
|
o = decode_wrapper.forward(
|
||||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
@@ -271,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
def _get_wrapper_idx(self, layer: nn.Module):
|
def _get_wrapper_idx(self, layer: RadixAttention):
|
||||||
if self.num_wrappers == 1:
|
if self.num_wrappers == 1:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
|
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
|
||||||
self.sliding_window_size = model_runner.sliding_window_size
|
self.sliding_window_size = model_runner.sliding_window_size
|
||||||
|
|
||||||
|
self.attn_backend = attn_backend
|
||||||
|
|
||||||
# Buffers and wrappers
|
# Buffers and wrappers
|
||||||
self.kv_indptr = attn_backend.kv_indptr
|
self.kv_indptr = attn_backend.kv_indptr
|
||||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||||
@@ -305,20 +326,27 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self.decode_wrappers = attn_backend.decode_wrappers
|
self.decode_wrappers = attn_backend.decode_wrappers
|
||||||
|
|
||||||
# Dispatch
|
# Dispatch
|
||||||
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||||
self.update = self.update_sliding_window
|
self.update = self.update_sliding_window
|
||||||
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||||
self.update = self.update_cross_attention
|
self.update = self.update_cross_attention
|
||||||
else:
|
else:
|
||||||
assert attn_backend.num_wrappers == 1
|
assert self.attn_backend.num_wrappers == 1
|
||||||
self.update = self.update_single_wrapper
|
self.update = self.update_single_wrapper
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
|
||||||
|
):
|
||||||
|
# Keep the signature for type checking, will be initialized during runtime
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def update_single_wrapper(
|
def update_single_wrapper(
|
||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers=None,
|
decode_wrappers=None,
|
||||||
|
encoder_lens=None,
|
||||||
):
|
):
|
||||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
self.call_begin_forward(
|
self.call_begin_forward(
|
||||||
@@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers=None,
|
decode_wrappers=None,
|
||||||
|
encoder_lens=None,
|
||||||
):
|
):
|
||||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
|
|
||||||
@@ -363,8 +392,35 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
kv_start_idx_tmp,
|
kv_start_idx_tmp,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_cross_attention(self):
|
def update_cross_attention(
|
||||||
raise NotImplementedError()
|
self,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
decode_wrappers=None,
|
||||||
|
encoder_lens=None,
|
||||||
|
):
|
||||||
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
|
|
||||||
|
for wrapper_id in range(2):
|
||||||
|
if wrapper_id == 0:
|
||||||
|
# Normal attention
|
||||||
|
paged_kernel_lens = seq_lens
|
||||||
|
kv_start_idx = encoder_lens
|
||||||
|
else:
|
||||||
|
# Cross attention
|
||||||
|
paged_kernel_lens = encoder_lens
|
||||||
|
kv_start_idx = torch.zeros_like(encoder_lens)
|
||||||
|
seq_lens_sum = encoder_lens.sum().item()
|
||||||
|
|
||||||
|
self.call_begin_forward(
|
||||||
|
decode_wrappers[wrapper_id],
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
self.kv_indptr[wrapper_id],
|
||||||
|
kv_start_idx,
|
||||||
|
)
|
||||||
|
|
||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
self,
|
self,
|
||||||
@@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
|
self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
|
||||||
self.sliding_window_size = model_runner.sliding_window_size
|
self.sliding_window_size = model_runner.sliding_window_size
|
||||||
|
|
||||||
|
self.attn_backend = attn_backend
|
||||||
|
|
||||||
# Buffers and wrappers
|
# Buffers and wrappers
|
||||||
self.kv_indptr = attn_backend.kv_indptr
|
self.kv_indptr = attn_backend.kv_indptr
|
||||||
self.kv_last_page_len = attn_backend.kv_last_page_len
|
self.kv_last_page_len = attn_backend.kv_last_page_len
|
||||||
@@ -430,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self.wrappers_paged = attn_backend.prefill_wrappers_paged
|
self.wrappers_paged = attn_backend.prefill_wrappers_paged
|
||||||
|
|
||||||
# Dispatch
|
# Dispatch
|
||||||
if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
|
||||||
self.update = self.update_sliding_window
|
self.update = self.update_sliding_window
|
||||||
elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
|
||||||
self.update = self.update_cross_attention
|
self.update = self.update_cross_attention
|
||||||
else:
|
else:
|
||||||
assert attn_backend.num_wrappers == 1
|
assert self.attn_backend.num_wrappers == 1
|
||||||
self.update = self.update_single_wrapper
|
self.update = self.update_single_wrapper
|
||||||
|
|
||||||
|
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
|
||||||
|
# Keep the signature for type checking, will be initialized during runtime
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def update_single_wrapper(
|
def update_single_wrapper(
|
||||||
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
|
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
|
||||||
):
|
):
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
paged_kernel_lens = prefix_lens
|
paged_kernel_lens = prefix_lens
|
||||||
@@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def update_sliding_window(
|
def update_sliding_window(
|
||||||
self, req_pool_indices, seq_lens, prefix_lens, use_ragged
|
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -487,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
use_ragged,
|
use_ragged,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_cross_attention(self):
|
def update_cross_attention(
|
||||||
raise NotImplementedError()
|
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
|
||||||
|
):
|
||||||
|
for wrapper_id in range(2):
|
||||||
|
if wrapper_id == 0:
|
||||||
|
# normal attention
|
||||||
|
paged_kernel_lens = seq_lens
|
||||||
|
kv_start_idx = encoder_lens
|
||||||
|
else:
|
||||||
|
# cross attention
|
||||||
|
paged_kernel_lens = encoder_lens
|
||||||
|
kv_start_idx = torch.zeros_like(encoder_lens)
|
||||||
|
|
||||||
|
self.call_begin_forward(
|
||||||
|
self.wrapper_ragged,
|
||||||
|
self.wrappers_paged[wrapper_id],
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
kv_start_idx,
|
||||||
|
self.kv_indptr[wrapper_id],
|
||||||
|
self.qo_indptr[wrapper_id],
|
||||||
|
use_ragged,
|
||||||
|
)
|
||||||
|
|
||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
encoder_lens=None,
|
||||||
):
|
):
|
||||||
|
# NOTE: encoder_lens expected to be zeros or None
|
||||||
self.forward_metadata = (
|
self.forward_metadata = (
|
||||||
self.cuda_graph_start_loc,
|
self.cuda_graph_start_loc,
|
||||||
self.cuda_graph_attn_logits,
|
self.cuda_graph_attn_logits,
|
||||||
@@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
|
encoder_lens=None,
|
||||||
):
|
):
|
||||||
|
# NOTE: encoder_lens expected to be zeros or None
|
||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_extend(
|
||||||
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
|
):
|
||||||
# TODO: reuse the buffer across layers
|
# TODO: reuse the buffer across layers
|
||||||
if layer.qk_head_dim != layer.v_head_dim:
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||||
@@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
o = torch.empty_like(q)
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
)
|
)
|
||||||
|
|
||||||
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
|
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
|
||||||
@@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_decode(
|
||||||
|
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
||||||
|
):
|
||||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||||
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||||
@@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
|
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
|
||||||
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer.layer_id, forward_batch.out_cache_loc, k, v
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
)
|
)
|
||||||
|
|
||||||
self.decode_attention_fwd(
|
self.decode_attention_fwd(
|
||||||
|
|||||||
@@ -33,20 +33,9 @@ def init_global_processor(server_args: ServerArgs):
|
|||||||
|
|
||||||
|
|
||||||
class BaseImageProcessor(ABC):
|
class BaseImageProcessor(ABC):
|
||||||
@abstractmethod
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
async def process_images_async(self, image_data, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DummyImageProcessor(BaseImageProcessor):
|
|
||||||
async def process_images_async(self, *args, **kwargs):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class LlavaImageProcessor(BaseImageProcessor):
|
|
||||||
def __init__(self, hf_config, server_args, _image_processor):
|
|
||||||
self.hf_config = hf_config
|
self.hf_config = hf_config
|
||||||
self._image_processor = _image_processor
|
self._processor = _processor
|
||||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||||
initializer=init_global_processor,
|
initializer=init_global_processor,
|
||||||
mp_context=mp.get_context("fork"),
|
mp_context=mp.get_context("fork"),
|
||||||
@@ -54,6 +43,23 @@ class LlavaImageProcessor(BaseImageProcessor):
|
|||||||
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
|
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def process_images_async(self, image_data, input_text, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DummyImageProcessor(BaseImageProcessor):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def process_images_async(self, *args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaImageProcessor(BaseImageProcessor):
|
||||||
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_single_image_task(
|
def _process_single_image_task(
|
||||||
image_data: Union[str, bytes],
|
image_data: Union[str, bytes],
|
||||||
@@ -119,7 +125,7 @@ class LlavaImageProcessor(BaseImageProcessor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def process_images_async(
|
async def process_images_async(
|
||||||
self, image_data: List[Union[str, bytes]], request_obj
|
self, image_data: List[Union[str, bytes]], input_text, request_obj
|
||||||
):
|
):
|
||||||
if not image_data:
|
if not image_data:
|
||||||
return None
|
return None
|
||||||
@@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaImageProcessor(BaseImageProcessor):
|
||||||
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_single_image_task(images, input_text):
|
||||||
|
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
|
||||||
|
return global_processor(images, input_text, return_tensors="pt")
|
||||||
|
|
||||||
|
async def _process_single_image(self, images, input_text):
|
||||||
|
if self.executor is not None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
image_inputs = await loop.run_in_executor(
|
||||||
|
self.executor,
|
||||||
|
MllamaImageProcessor._process_single_image_task,
|
||||||
|
images,
|
||||||
|
input_text,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
image_inputs = self._processor(images, input_text, return_tensors="pt")
|
||||||
|
|
||||||
|
return image_inputs
|
||||||
|
|
||||||
|
async def process_images_async(
|
||||||
|
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||||
|
):
|
||||||
|
if not image_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(input_text, list):
|
||||||
|
assert len(input_text) and isinstance(input_text[0], int)
|
||||||
|
input_text = self._processor.tokenizer.decode(input_text)
|
||||||
|
|
||||||
|
if not isinstance(image_data, list):
|
||||||
|
image_data = [image_data]
|
||||||
|
|
||||||
|
if len(image_data) > 0:
|
||||||
|
images = [load_image(image)[0] for image in image_data]
|
||||||
|
else:
|
||||||
|
images = load_image(image_data[0])[0]
|
||||||
|
|
||||||
|
image_inputs = await self._process_single_image(images, input_text)
|
||||||
|
image_inputs["image_hashes"] = [hash(str(image_data))]
|
||||||
|
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||||
|
|
||||||
|
return image_inputs
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLImageProcessor(BaseImageProcessor):
|
class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||||
def __init__(self, hf_config, server_args, _image_processor):
|
def __init__(self, hf_config, server_args, _image_processor):
|
||||||
self.hf_config = hf_config
|
self.hf_config = hf_config
|
||||||
@@ -237,7 +291,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
|||||||
return self._process_single_image_task(image_data)
|
return self._process_single_image_task(image_data)
|
||||||
|
|
||||||
async def process_images_async(
|
async def process_images_async(
|
||||||
self, image_data: List[Union[str, bytes]], request_obj
|
self, image_data: List[Union[str, bytes]], input_text, request_obj
|
||||||
):
|
):
|
||||||
if not image_data:
|
if not image_data:
|
||||||
return None
|
return None
|
||||||
@@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
|
|
||||||
def get_image_processor(
|
def get_image_processor(
|
||||||
hf_config, server_args: ServerArgs, _image_processor
|
hf_config, server_args: ServerArgs, processor
|
||||||
) -> BaseImageProcessor:
|
) -> BaseImageProcessor:
|
||||||
if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
|
if "MllamaForConditionalGeneration" in hf_config.architectures:
|
||||||
return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
|
return MllamaImageProcessor(hf_config, server_args, processor)
|
||||||
|
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
|
||||||
|
return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
|
||||||
else:
|
else:
|
||||||
return LlavaImageProcessor(hf_config, server_args, _image_processor)
|
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
|
||||||
|
|
||||||
|
|
||||||
def get_dummy_image_processor():
|
def get_dummy_image_processor():
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained import RegexGuide
|
from sglang.srt.constrained import RegexGuide
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
@@ -121,11 +122,12 @@ class ImageInputs:
|
|||||||
"""The image related inputs."""
|
"""The image related inputs."""
|
||||||
|
|
||||||
pixel_values: torch.Tensor
|
pixel_values: torch.Tensor
|
||||||
image_hash: int
|
image_hashes: Optional[list] = None
|
||||||
image_sizes: Optional[list] = None
|
image_sizes: Optional[list] = None
|
||||||
image_offsets: Optional[list] = None
|
image_offsets: Optional[list] = None
|
||||||
pad_values: Optional[list] = None
|
pad_values: Optional[list] = None
|
||||||
modalities: Optional[list] = None
|
modalities: Optional[list] = None
|
||||||
|
num_image_tokens: Optional[int] = None
|
||||||
|
|
||||||
image_embeds: Optional[List[torch.Tensor]] = None
|
image_embeds: Optional[List[torch.Tensor]] = None
|
||||||
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
|
||||||
@@ -138,19 +140,27 @@ class ImageInputs:
|
|||||||
# Use image hash as fake token_ids, which is then used for prefix matching
|
# Use image hash as fake token_ids, which is then used for prefix matching
|
||||||
ret = ImageInputs(
|
ret = ImageInputs(
|
||||||
pixel_values=obj["pixel_values"],
|
pixel_values=obj["pixel_values"],
|
||||||
image_hash=hash(tuple(obj["image_hashes"])),
|
image_hashes=hash(tuple(obj["image_hashes"])),
|
||||||
image_grid_thws=obj.get("image_grid_thws"),
|
|
||||||
)
|
)
|
||||||
image_hash = ret.image_hash
|
image_hash = ret.image_hashes
|
||||||
ret.pad_values = [
|
ret.pad_values = [
|
||||||
(image_hash) % vocab_size,
|
(image_hash) % vocab_size,
|
||||||
(image_hash >> 16) % vocab_size,
|
(image_hash >> 16) % vocab_size,
|
||||||
(image_hash >> 32) % vocab_size,
|
(image_hash >> 32) % vocab_size,
|
||||||
(image_hash >> 64) % vocab_size,
|
(image_hash >> 64) % vocab_size,
|
||||||
]
|
]
|
||||||
ret.image_sizes = obj["image_sizes"]
|
|
||||||
# Only when pixel values is not None we have modalities
|
optional_args = [
|
||||||
ret.modalities = obj["modalities"] or ["image"]
|
"image_sizes",
|
||||||
|
"modalities",
|
||||||
|
"aspect_ratio_ids",
|
||||||
|
"aspect_ratio_mask",
|
||||||
|
"image_grid_thws",
|
||||||
|
]
|
||||||
|
for arg in optional_args:
|
||||||
|
if arg in obj:
|
||||||
|
setattr(ret, arg, obj[arg])
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@@ -416,6 +426,10 @@ class ScheduleBatch:
|
|||||||
req_to_token_pool: ReqToTokenPool = None
|
req_to_token_pool: ReqToTokenPool = None
|
||||||
token_to_kv_pool: BaseTokenToKVPool = None
|
token_to_kv_pool: BaseTokenToKVPool = None
|
||||||
tree_cache: BasePrefixCache = None
|
tree_cache: BasePrefixCache = None
|
||||||
|
|
||||||
|
# For utility
|
||||||
|
model_config: ModelConfig = None
|
||||||
|
|
||||||
forward_mode: ForwardMode = None
|
forward_mode: ForwardMode = None
|
||||||
sampling_info: SamplingBatchInfo = None
|
sampling_info: SamplingBatchInfo = None
|
||||||
|
|
||||||
@@ -440,6 +454,12 @@ class ScheduleBatch:
|
|||||||
extend_num_tokens: int = None
|
extend_num_tokens: int = None
|
||||||
decoding_reqs: List[Req] = None
|
decoding_reqs: List[Req] = None
|
||||||
|
|
||||||
|
# For encoder-decoder
|
||||||
|
encoder_cached: Optional[List[bool]] = None
|
||||||
|
encoder_lens: Optional[torch.Tensor] = None
|
||||||
|
encoder_lens_cpu: Optional[List[int]] = None
|
||||||
|
encoder_out_cache_loc: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Stream
|
# Stream
|
||||||
has_stream: bool = False
|
has_stream: bool = False
|
||||||
|
|
||||||
@@ -450,12 +470,20 @@ class ScheduleBatch:
|
|||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
def init_new(
|
||||||
|
cls,
|
||||||
|
reqs,
|
||||||
|
req_to_token_pool,
|
||||||
|
token_to_kv_pool,
|
||||||
|
tree_cache,
|
||||||
|
model_config,
|
||||||
|
):
|
||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
req_to_token_pool=req_to_token_pool,
|
req_to_token_pool=req_to_token_pool,
|
||||||
token_to_kv_pool=token_to_kv_pool,
|
token_to_kv_pool=token_to_kv_pool,
|
||||||
tree_cache=tree_cache,
|
tree_cache=tree_cache,
|
||||||
|
model_config=model_config,
|
||||||
return_logprob=any(req.return_logprob for req in reqs),
|
return_logprob=any(req.return_logprob for req in reqs),
|
||||||
has_stream=any(req.stream for req in reqs),
|
has_stream=any(req.stream for req in reqs),
|
||||||
has_regex=any(req.regex_fsm for req in reqs),
|
has_regex=any(req.regex_fsm for req in reqs),
|
||||||
@@ -493,7 +521,78 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
return out_cache_loc
|
return out_cache_loc
|
||||||
|
|
||||||
def prepare_for_extend(self, vocab_size: int):
|
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||||
|
self.encoder_lens_cpu = []
|
||||||
|
self.encoder_cached = []
|
||||||
|
|
||||||
|
for req in self.reqs:
|
||||||
|
im = req.image_inputs
|
||||||
|
if im is None or im.num_image_tokens is None:
|
||||||
|
# No image input
|
||||||
|
self.encoder_lens_cpu.append(0)
|
||||||
|
self.encoder_cached.append(True)
|
||||||
|
else:
|
||||||
|
self.encoder_lens_cpu.append(im.num_image_tokens)
|
||||||
|
self.encoder_cached.append(
|
||||||
|
self.forward_mode.is_decode()
|
||||||
|
or len(req.prefix_indices) >= im.num_image_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
|
||||||
|
self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Strip encoder infos
|
||||||
|
pt = 0
|
||||||
|
decoder_out_cache_loc = []
|
||||||
|
encoder_out_cache_loc = []
|
||||||
|
for i, req in enumerate(self.reqs):
|
||||||
|
encoder_len = self.encoder_lens_cpu[i]
|
||||||
|
seq_lens[i] -= encoder_len
|
||||||
|
|
||||||
|
if len(req.prefix_indices) < encoder_len:
|
||||||
|
# NOTE: the encoder part should considered as a whole
|
||||||
|
assert len(req.prefix_indices) == 0
|
||||||
|
input_ids[i] = input_ids[i][encoder_len:]
|
||||||
|
encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
|
||||||
|
decoder_out_cache_loc.append(
|
||||||
|
self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
|
||||||
|
)
|
||||||
|
self.extend_lens[i] -= encoder_len
|
||||||
|
self.extend_num_tokens -= encoder_len
|
||||||
|
else:
|
||||||
|
decoder_out_cache_loc.append(
|
||||||
|
self.out_cache_loc[pt : pt + req.extend_input_len]
|
||||||
|
)
|
||||||
|
self.prefix_lens[i] -= encoder_len
|
||||||
|
|
||||||
|
pt += req.extend_input_len
|
||||||
|
|
||||||
|
# Reassign
|
||||||
|
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||||
|
self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
|
||||||
|
self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if not decoder_out_cache_loc:
|
||||||
|
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
|
||||||
|
self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.out_cache_loc = torch.cat(decoder_out_cache_loc)
|
||||||
|
|
||||||
|
if not encoder_out_cache_loc:
|
||||||
|
self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
|
||||||
|
self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
|
||||||
|
|
||||||
|
assert len(self.out_cache_loc) == self.extend_num_tokens
|
||||||
|
|
||||||
|
def prepare_for_extend(self):
|
||||||
self.forward_mode = ForwardMode.EXTEND
|
self.forward_mode = ForwardMode.EXTEND
|
||||||
|
|
||||||
bs = len(self.reqs)
|
bs = len(self.reqs)
|
||||||
@@ -561,8 +660,13 @@ class ScheduleBatch:
|
|||||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||||
|
|
||||||
|
if self.model_config.is_encoder_decoder:
|
||||||
|
self.prepare_encoder_info_extend(input_ids, seq_lens)
|
||||||
|
|
||||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||||
self, vocab_size, global_server_args_dict["disable_penalizer"]
|
self,
|
||||||
|
self.model_config.vocab_size,
|
||||||
|
global_server_args_dict["disable_penalizer"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||||
@@ -752,6 +856,10 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
return jump_forward_reqs
|
return jump_forward_reqs
|
||||||
|
|
||||||
|
def prepare_encoder_info_decode(self):
|
||||||
|
# Reset the encoder cached status
|
||||||
|
self.encoder_cached = [True] * len(self.reqs)
|
||||||
|
|
||||||
def prepare_for_decode(self, enable_overlap: bool = False):
|
def prepare_for_decode(self, enable_overlap: bool = False):
|
||||||
self.forward_mode = ForwardMode.DECODE
|
self.forward_mode = ForwardMode.DECODE
|
||||||
|
|
||||||
@@ -766,16 +874,22 @@ class ScheduleBatch:
|
|||||||
bs = len(self.reqs)
|
bs = len(self.reqs)
|
||||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||||
|
|
||||||
|
if self.model_config.is_encoder_decoder:
|
||||||
|
locs = self.encoder_lens + self.seq_lens
|
||||||
|
self.prepare_encoder_info_decode()
|
||||||
|
else:
|
||||||
|
locs = self.seq_lens
|
||||||
|
|
||||||
if enable_overlap:
|
if enable_overlap:
|
||||||
# Do not use in-place operations in the overlap mode
|
# Do not use in-place operations in the overlap mode
|
||||||
self.req_to_token_pool.write(
|
self.req_to_token_pool.write(
|
||||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
(self.req_pool_indices, locs), self.out_cache_loc
|
||||||
)
|
)
|
||||||
self.seq_lens = self.seq_lens + 1
|
self.seq_lens = self.seq_lens + 1
|
||||||
else:
|
else:
|
||||||
# A faster in-place version
|
# A faster in-place version
|
||||||
self.req_to_token_pool.write(
|
self.req_to_token_pool.write(
|
||||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
(self.req_pool_indices, locs), self.out_cache_loc
|
||||||
)
|
)
|
||||||
self.seq_lens.add_(1)
|
self.seq_lens.add_(1)
|
||||||
self.seq_lens_sum += bs
|
self.seq_lens_sum += bs
|
||||||
@@ -802,6 +916,10 @@ class ScheduleBatch:
|
|||||||
# No need to filter
|
# No need to filter
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self.model_config.is_encoder_decoder:
|
||||||
|
self.encoder_lens = self.encoder_lens[keep_indices]
|
||||||
|
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
||||||
|
|
||||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||||
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
|
new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
|
||||||
self.device, non_blocking=True
|
self.device, non_blocking=True
|
||||||
@@ -828,6 +946,11 @@ class ScheduleBatch:
|
|||||||
# needs to be called with pre-merged Batch.reqs.
|
# needs to be called with pre-merged Batch.reqs.
|
||||||
self.sampling_info.merge_batch(other.sampling_info)
|
self.sampling_info.merge_batch(other.sampling_info)
|
||||||
|
|
||||||
|
# Encoder-decoder infos
|
||||||
|
if self.model_config.is_encoder_decoder:
|
||||||
|
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
|
||||||
|
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
|
||||||
|
|
||||||
self.req_pool_indices = torch.concat(
|
self.req_pool_indices = torch.concat(
|
||||||
[self.req_pool_indices, other.req_pool_indices]
|
[self.req_pool_indices, other.req_pool_indices]
|
||||||
)
|
)
|
||||||
@@ -850,14 +973,11 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
def get_model_worker_batch(self):
|
def get_model_worker_batch(self):
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
|
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||||
image_inputs
|
|
||||||
) = None
|
|
||||||
else:
|
else:
|
||||||
extend_seq_lens = self.extend_lens
|
extend_seq_lens = self.extend_lens
|
||||||
extend_prefix_lens = self.prefix_lens
|
extend_prefix_lens = self.prefix_lens
|
||||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||||
image_inputs = [r.image_inputs for r in self.reqs]
|
|
||||||
|
|
||||||
if self.has_regex:
|
if self.has_regex:
|
||||||
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
|
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
|
||||||
@@ -887,7 +1007,11 @@ class ScheduleBatch:
|
|||||||
extend_seq_lens=extend_seq_lens,
|
extend_seq_lens=extend_seq_lens,
|
||||||
extend_prefix_lens=extend_prefix_lens,
|
extend_prefix_lens=extend_prefix_lens,
|
||||||
extend_logprob_start_lens=extend_logprob_start_lens,
|
extend_logprob_start_lens=extend_logprob_start_lens,
|
||||||
image_inputs=image_inputs,
|
image_inputs=[r.image_inputs for r in self.reqs],
|
||||||
|
encoder_cached=self.encoder_cached,
|
||||||
|
encoder_lens=self.encoder_lens,
|
||||||
|
encoder_lens_cpu=self.encoder_lens_cpu,
|
||||||
|
encoder_out_cache_loc=self.encoder_out_cache_loc,
|
||||||
lora_paths=[req.lora_path for req in self.reqs],
|
lora_paths=[req.lora_path for req in self.reqs],
|
||||||
sampling_info=self.sampling_info,
|
sampling_info=self.sampling_info,
|
||||||
mrope_positions_delta=mrope_positions_delta,
|
mrope_positions_delta=mrope_positions_delta,
|
||||||
@@ -897,6 +1021,7 @@ class ScheduleBatch:
|
|||||||
# Only contain fields that will be used by process_batch_result
|
# Only contain fields that will be used by process_batch_result
|
||||||
return ScheduleBatch(
|
return ScheduleBatch(
|
||||||
reqs=self.reqs,
|
reqs=self.reqs,
|
||||||
|
model_config=self.model_config,
|
||||||
forward_mode=self.forward_mode,
|
forward_mode=self.forward_mode,
|
||||||
out_cache_loc=self.out_cache_loc,
|
out_cache_loc=self.out_cache_loc,
|
||||||
return_logprob=self.return_logprob,
|
return_logprob=self.return_logprob,
|
||||||
@@ -944,6 +1069,12 @@ class ModelWorkerBatch:
|
|||||||
# For multimodal
|
# For multimodal
|
||||||
image_inputs: Optional[List[ImageInputs]]
|
image_inputs: Optional[List[ImageInputs]]
|
||||||
|
|
||||||
|
# For encoder-decoder
|
||||||
|
encoder_cached: Optional[List[bool]]
|
||||||
|
encoder_lens: Optional[torch.Tensor]
|
||||||
|
encoder_lens_cpu: Optional[List[int]]
|
||||||
|
encoder_out_cache_loc: Optional[torch.Tensor]
|
||||||
|
|
||||||
# For LoRA
|
# For LoRA
|
||||||
lora_paths: Optional[List[str]]
|
lora_paths: Optional[List[str]]
|
||||||
|
|
||||||
|
|||||||
@@ -662,8 +662,9 @@ class Scheduler:
|
|||||||
self.req_to_token_pool,
|
self.req_to_token_pool,
|
||||||
self.token_to_kv_pool,
|
self.token_to_kv_pool,
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
|
self.model_config,
|
||||||
)
|
)
|
||||||
new_batch.prepare_for_extend(self.model_config.vocab_size)
|
new_batch.prepare_for_extend()
|
||||||
|
|
||||||
# Mixed-style chunked prefill
|
# Mixed-style chunked prefill
|
||||||
if self.is_mixed_chunk and self.running_batch is not None:
|
if self.is_mixed_chunk and self.running_batch is not None:
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
# We want to parallelize the image pre-processing so we create an executor for it
|
# We want to parallelize the image pre-processing so we create an executor for it
|
||||||
self.image_processor = get_image_processor(
|
self.image_processor = get_image_processor(
|
||||||
self.hf_config, server_args, self.processor.image_processor
|
self.hf_config, server_args, self.processor
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.tokenizer = get_tokenizer(
|
self.tokenizer = get_tokenizer(
|
||||||
@@ -191,8 +191,10 @@ class TokenizerManager:
|
|||||||
sampling_params = self._get_sampling_params(obj.sampling_params)
|
sampling_params = self._get_sampling_params(obj.sampling_params)
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
image_inputs = await self.image_processor.process_images_async(
|
image_inputs = await self.image_processor.process_images_async(
|
||||||
obj.image_data, obj
|
obj.image_data, input_text or input_ids, obj
|
||||||
)
|
)
|
||||||
|
if image_inputs and "input_ids" in image_inputs:
|
||||||
|
input_ids = image_inputs["input_ids"]
|
||||||
return_logprob = obj.return_logprob
|
return_logprob = obj.return_logprob
|
||||||
logprob_start_len = obj.logprob_start_len
|
logprob_start_len = obj.logprob_start_len
|
||||||
top_logprobs_num = obj.top_logprobs_num
|
top_logprobs_num = obj.top_logprobs_num
|
||||||
@@ -217,8 +219,10 @@ class TokenizerManager:
|
|||||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
image_inputs = await self.image_processor.process_images_async(
|
image_inputs = await self.image_processor.process_images_async(
|
||||||
obj.image_data[index], obj
|
obj.image_data[index], input_text or input_ids, obj
|
||||||
)
|
)
|
||||||
|
if image_inputs and "input_ids" in image_inputs:
|
||||||
|
input_ids = image_inputs["input_ids"]
|
||||||
return_logprob = obj.return_logprob[index]
|
return_logprob = obj.return_logprob[index]
|
||||||
logprob_start_len = obj.logprob_start_len[index]
|
logprob_start_len = obj.logprob_start_len[index]
|
||||||
top_logprobs_num = obj.top_logprobs_num[index]
|
top_logprobs_num = obj.top_logprobs_num[index]
|
||||||
@@ -263,8 +267,10 @@ class TokenizerManager:
|
|||||||
sampling_params = SamplingParams(**obj.sampling_params[0])
|
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||||
sampling_params.max_new_tokens = 0
|
sampling_params.max_new_tokens = 0
|
||||||
image_inputs = await self.image_processor.process_images_async(
|
image_inputs = await self.image_processor.process_images_async(
|
||||||
obj.image_data[0], obj
|
obj.image_data[0], input_text or input_ids, obj
|
||||||
)
|
)
|
||||||
|
if image_inputs and "input_ids" in image_inputs:
|
||||||
|
input_ids = image_inputs["input_ids"]
|
||||||
return_logprob = obj.return_logprob[0]
|
return_logprob = obj.return_logprob[0]
|
||||||
logprob_start_len = obj.logprob_start_len[0]
|
logprob_start_len = obj.logprob_start_len[0]
|
||||||
top_logprobs_num = obj.top_logprobs_num[0]
|
top_logprobs_num = obj.top_logprobs_num[0]
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ from typing import List, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -41,13 +43,17 @@ class ReqToTokenPool:
|
|||||||
)
|
)
|
||||||
self.free_slots = list(range(size))
|
self.free_slots = list(range(size))
|
||||||
self.write_records = []
|
self.write_records = []
|
||||||
|
self.use_records = use_records
|
||||||
|
|
||||||
if use_records:
|
if self.use_records:
|
||||||
# records all write operations
|
|
||||||
self.write = self.write_with_records
|
self.write = self.write_with_records
|
||||||
else:
|
else:
|
||||||
self.write = self.write_without_records
|
self.write = self.write_without_records
|
||||||
|
|
||||||
|
def write(self, indices, values):
|
||||||
|
# Keep the signature for type checking, will be initialized during runtime
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def available_size(self):
|
def available_size(self):
|
||||||
return len(self.free_slots)
|
return len(self.free_slots)
|
||||||
|
|
||||||
@@ -154,7 +160,7 @@ class BaseTokenToKVPool:
|
|||||||
|
|
||||||
def set_kv_buffer(
|
def set_kv_buffer(
|
||||||
self,
|
self,
|
||||||
layer_id: int,
|
layer: RadixAttention,
|
||||||
loc: torch.Tensor,
|
loc: torch.Tensor,
|
||||||
cache_k: torch.Tensor,
|
cache_k: torch.Tensor,
|
||||||
cache_v: torch.Tensor,
|
cache_v: torch.Tensor,
|
||||||
@@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
|||||||
|
|
||||||
def set_kv_buffer(
|
def set_kv_buffer(
|
||||||
self,
|
self,
|
||||||
layer_id: int,
|
layer: RadixAttention,
|
||||||
loc: torch.Tensor,
|
loc: torch.Tensor,
|
||||||
cache_k: torch.Tensor,
|
cache_k: torch.Tensor,
|
||||||
cache_v: torch.Tensor,
|
cache_v: torch.Tensor,
|
||||||
):
|
):
|
||||||
|
layer_id = layer.layer_id
|
||||||
if cache_k.dtype != self.dtype:
|
if cache_k.dtype != self.dtype:
|
||||||
cache_k = cache_k.to(self.dtype)
|
cache_k = cache_k.to(self.dtype)
|
||||||
if cache_v.dtype != self.dtype:
|
if cache_v.dtype != self.dtype:
|
||||||
@@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
|||||||
|
|
||||||
def set_kv_buffer(
|
def set_kv_buffer(
|
||||||
self,
|
self,
|
||||||
layer_id: int,
|
layer: RadixAttention,
|
||||||
loc: torch.Tensor,
|
loc: torch.Tensor,
|
||||||
cache_k: torch.Tensor,
|
cache_k: torch.Tensor,
|
||||||
cache_v: torch.Tensor,
|
cache_v: torch.Tensor,
|
||||||
):
|
):
|
||||||
|
layer_id = layer.layer_id
|
||||||
if cache_k.dtype != self.dtype:
|
if cache_k.dtype != self.dtype:
|
||||||
cache_k = cache_k.to(self.dtype)
|
cache_k = cache_k.to(self.dtype)
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
@@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
|
|||||||
|
|
||||||
def set_kv_buffer(
|
def set_kv_buffer(
|
||||||
self,
|
self,
|
||||||
layer_id: int,
|
layer: RadixAttention,
|
||||||
loc: torch.Tensor,
|
loc: torch.Tensor,
|
||||||
cache_k: torch.Tensor,
|
cache_k: torch.Tensor,
|
||||||
cache_v: torch.Tensor,
|
cache_v: torch.Tensor,
|
||||||
cache_label: torch.Tensor,
|
cache_label: torch.Tensor,
|
||||||
):
|
):
|
||||||
# NOTE(Andy): ignore the dtype check
|
# NOTE(Andy): ignore the dtype check
|
||||||
|
layer_id = layer.layer_id
|
||||||
self.k_buffer[layer_id][loc] = cache_k
|
self.k_buffer[layer_id][loc] = cache_k
|
||||||
self.v_buffer[layer_id][loc] = cache_v
|
self.v_buffer[layer_id][loc] = cache_v
|
||||||
self.label_buffer[layer_id][loc] = cache_label
|
self.label_buffer[layer_id][loc] = cache_label
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ class CudaGraphRunner:
|
|||||||
self.graph_memory_pool = None
|
self.graph_memory_pool = None
|
||||||
self.use_torch_compile = model_runner.server_args.enable_torch_compile
|
self.use_torch_compile = model_runner.server_args.enable_torch_compile
|
||||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||||
|
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
|
||||||
|
|
||||||
# Batch sizes to capture
|
# Batch sizes to capture
|
||||||
if self.model_runner.server_args.disable_cuda_graph_padding:
|
if self.model_runner.server_args.disable_cuda_graph_padding:
|
||||||
@@ -132,6 +133,9 @@ class CudaGraphRunner:
|
|||||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
||||||
|
self.encoder_len_fill_value = 0
|
||||||
|
|
||||||
if self.use_torch_compile:
|
if self.use_torch_compile:
|
||||||
set_torch_compile_config()
|
set_torch_compile_config()
|
||||||
|
|
||||||
@@ -144,9 +148,18 @@ class CudaGraphRunner:
|
|||||||
)
|
)
|
||||||
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
|
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||||
|
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
||||||
|
self.encoder_lens = torch.full(
|
||||||
|
(self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.encoder_lens = None
|
||||||
|
|
||||||
# Capture
|
# Capture
|
||||||
try:
|
try:
|
||||||
self.capture()
|
with self.model_capture_mode():
|
||||||
|
self.capture()
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Capture cuda graph failed: {e}\n"
|
f"Capture cuda graph failed: {e}\n"
|
||||||
@@ -157,11 +170,32 @@ class CudaGraphRunner:
|
|||||||
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
||||||
)
|
)
|
||||||
|
|
||||||
def can_run(self, batch_size: int):
|
@contextmanager
|
||||||
if self.disable_padding:
|
def model_capture_mode(self):
|
||||||
return batch_size in self.graphs
|
if hasattr(self.model_runner.model, "capture_mode"):
|
||||||
else:
|
self.model_runner.model.capture_mode = True
|
||||||
return batch_size <= self.max_bs
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
if hasattr(self.model_runner.model, "capture_mode"):
|
||||||
|
self.model_runner.model.capture_mode = False
|
||||||
|
|
||||||
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
|
is_bs_supported = (
|
||||||
|
forward_batch.batch_size in self.graphs
|
||||||
|
if self.disable_padding
|
||||||
|
else forward_batch.batch_size <= self.max_bs
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
|
||||||
|
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
|
||||||
|
# because the full_text_row_masked_out_mask tensor will always be ones
|
||||||
|
is_encoder_lens_supported = (
|
||||||
|
torch.all(forward_batch.encoder_lens > 0)
|
||||||
|
if self.is_encoder_decoder
|
||||||
|
else True
|
||||||
|
)
|
||||||
|
return is_bs_supported and is_encoder_lens_supported
|
||||||
|
|
||||||
def capture(self):
|
def capture(self):
|
||||||
with graph_capture() as graph_capture_context:
|
with graph_capture() as graph_capture_context:
|
||||||
@@ -188,11 +222,19 @@ class CudaGraphRunner:
|
|||||||
req_pool_indices = self.req_pool_indices[:bs]
|
req_pool_indices = self.req_pool_indices[:bs]
|
||||||
seq_lens = self.seq_lens[:bs]
|
seq_lens = self.seq_lens[:bs]
|
||||||
out_cache_loc = self.out_cache_loc[:bs]
|
out_cache_loc = self.out_cache_loc[:bs]
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
encoder_lens = self.encoder_lens[:bs]
|
||||||
|
else:
|
||||||
|
encoder_lens = None
|
||||||
|
|
||||||
seq_lens_sum = seq_lens.sum().item()
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||||
bs, req_pool_indices, seq_lens
|
bs,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
encoder_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run and capture
|
# Run and capture
|
||||||
@@ -208,6 +250,7 @@ class CudaGraphRunner:
|
|||||||
attn_backend=self.model_runner.attn_backend,
|
attn_backend=self.model_runner.attn_backend,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
seq_lens_sum=seq_lens_sum,
|
seq_lens_sum=seq_lens_sum,
|
||||||
|
encoder_lens=encoder_lens,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
top_logprobs_nums=[0] * bs,
|
top_logprobs_nums=[0] * bs,
|
||||||
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
||||||
@@ -251,6 +294,8 @@ class CudaGraphRunner:
|
|||||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||||
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
|
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
@@ -258,6 +303,7 @@ class CudaGraphRunner:
|
|||||||
self.req_pool_indices,
|
self.req_pool_indices,
|
||||||
self.seq_lens,
|
self.seq_lens,
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.seq_lens_sum,
|
||||||
|
self.encoder_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replay
|
# Replay
|
||||||
|
|||||||
@@ -108,6 +108,12 @@ class ForwardBatch:
|
|||||||
# For multimodal
|
# For multimodal
|
||||||
image_inputs: Optional[List[ImageInputs]] = None
|
image_inputs: Optional[List[ImageInputs]] = None
|
||||||
|
|
||||||
|
# Encoder-decoder
|
||||||
|
encoder_cached: Optional[List[bool]] = None
|
||||||
|
encoder_lens: Optional[torch.Tensor] = None
|
||||||
|
encoder_lens_cpu: Optional[List[int]] = None
|
||||||
|
encoder_out_cache_loc: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# For LoRA
|
# For LoRA
|
||||||
lora_paths: Optional[List[str]] = None
|
lora_paths: Optional[List[str]] = None
|
||||||
|
|
||||||
@@ -194,6 +200,11 @@ class ForwardBatch:
|
|||||||
req_pool_indices=batch.req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
seq_lens=batch.seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
|
image_inputs=batch.image_inputs,
|
||||||
|
encoder_cached=batch.encoder_cached,
|
||||||
|
encoder_lens=batch.encoder_lens,
|
||||||
|
encoder_lens_cpu=batch.encoder_lens_cpu,
|
||||||
|
encoder_out_cache_loc=batch.encoder_out_cache_loc,
|
||||||
seq_lens_sum=batch.seq_lens_sum,
|
seq_lens_sum=batch.seq_lens_sum,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
@@ -212,11 +223,11 @@ class ForwardBatch:
|
|||||||
],
|
],
|
||||||
axis=0,
|
axis=0,
|
||||||
)
|
)
|
||||||
ret.image_inputs = batch.image_inputs
|
|
||||||
ret.extend_num_tokens = batch.extend_num_tokens
|
ret.extend_num_tokens = batch.extend_num_tokens
|
||||||
ret.extend_seq_lens = torch.tensor(
|
ret.extend_seq_lens = torch.tensor(
|
||||||
batch.extend_seq_lens, dtype=torch.int32
|
batch.extend_seq_lens, dtype=torch.int32
|
||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
|
|
||||||
ret.extend_prefix_lens = torch.tensor(
|
ret.extend_prefix_lens = torch.tensor(
|
||||||
batch.extend_prefix_lens, dtype=torch.int32
|
batch.extend_prefix_lens, dtype=torch.int32
|
||||||
).to(device, non_blocking=True)
|
).to(device, non_blocking=True)
|
||||||
|
|||||||
@@ -270,7 +270,6 @@ class ModelRunner:
|
|||||||
if hasattr(self.model, "get_attention_sliding_window_size")
|
if hasattr(self.model, "get_attention_sliding_window_size")
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
|
|
||||||
self.is_generation = is_generation_model(
|
self.is_generation = is_generation_model(
|
||||||
self.model_config.hf_config.architectures, self.server_args.is_embedding
|
self.model_config.hf_config.architectures, self.server_args.is_embedding
|
||||||
)
|
)
|
||||||
@@ -510,7 +509,7 @@ class ModelRunner:
|
|||||||
"Window attention is not supported in the triton attention backend. "
|
"Window attention is not supported in the triton attention backend. "
|
||||||
"Please use `--attention-backend flashinfer`."
|
"Please use `--attention-backend flashinfer`."
|
||||||
)
|
)
|
||||||
assert not self.has_cross_attention, (
|
assert not self.model_config.is_encoder_decoder, (
|
||||||
"Cross attention is not supported in the triton attention backend. "
|
"Cross attention is not supported in the triton attention backend. "
|
||||||
"Please use `--attention-backend flashinfer`."
|
"Please use `--attention-backend flashinfer`."
|
||||||
)
|
)
|
||||||
@@ -558,9 +557,7 @@ class ModelRunner:
|
|||||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||||
|
|
||||||
def forward_decode(self, forward_batch: ForwardBatch):
|
def forward_decode(self, forward_batch: ForwardBatch):
|
||||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
|
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
||||||
forward_batch.batch_size
|
|
||||||
):
|
|
||||||
return self.cuda_graph_runner.replay(forward_batch)
|
return self.cuda_graph_runner.replay(forward_batch)
|
||||||
|
|
||||||
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
|
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
|
||||||
|
|||||||
1004
python/sglang/srt/models/mllama.py
Normal file
1004
python/sglang/srt/models/mllama.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -605,7 +605,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|||||||
]
|
]
|
||||||
|
|
||||||
positions = forward_batch.mrope_positions
|
positions = forward_batch.mrope_positions
|
||||||
if image_inputs is None or len(image_inputs) == 0:
|
if (
|
||||||
|
forward_batch.forward_mode.is_decode()
|
||||||
|
or image_inputs is None
|
||||||
|
or len(image_inputs) == 0
|
||||||
|
):
|
||||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||||
else:
|
else:
|
||||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||||
|
|||||||
@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
|
|||||||
or "LlavaQwenForCausalLM" in model_architectures
|
or "LlavaQwenForCausalLM" in model_architectures
|
||||||
or "LlavaMistralForCausalLM" in model_architectures
|
or "LlavaMistralForCausalLM" in model_architectures
|
||||||
or "LlavaVidForCausalLM" in model_architectures
|
or "LlavaVidForCausalLM" in model_architectures
|
||||||
|
or "MllamaForConditionalGeneration" in model_architectures
|
||||||
or "Qwen2VLForConditionalGeneration" in model_architectures
|
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
|||||||
assert isinstance(text, str)
|
assert isinstance(text, str)
|
||||||
print(text)
|
print(text)
|
||||||
assert "man" in text or "cab" in text, text
|
assert "man" in text or "cab" in text, text
|
||||||
assert "logo" in text, text
|
assert "logo" in text or '"S"' in text or "SG" in text, text
|
||||||
assert response.id
|
assert response.id
|
||||||
assert response.created
|
assert response.created
|
||||||
assert response.usage.prompt_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
@@ -363,5 +363,27 @@ class TestQWen2VLServer(TestOpenAIVisionServer):
|
|||||||
cls.base_url += "/v1"
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMllamaServer(TestOpenAIVisionServer):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.api_key = "sk-123456"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
api_key=cls.api_key,
|
||||||
|
other_args=[
|
||||||
|
"--chat-template",
|
||||||
|
"llama_3_vision",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
def test_video_chat_completion(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user