diff --git a/python/pyproject.toml b/python/pyproject.toml index df6236162..d51fc2331 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -8,16 +8,12 @@ version = "0.3.4" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" -license = {file = "LICENSE"} +license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] -dependencies = [ - "requests", - "tqdm", - "numpy", -] +dependencies = ["requests", "tqdm", "numpy"] [project.optional-dependencies] runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", @@ -32,7 +28,14 @@ srt_xpu = ["sglang[runtime_common]"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] -test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"] +test = [ + "jsonlines", + "matplotlib", + "pandas", + "sentence_transformers", + "accelerate", + "peft", +] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] dev = ["sglang[all]", "sglang[test]"] @@ -43,7 +46,23 @@ dev_xpu = ["sglang[all_xpu]", "sglang[test]"] "Bug Tracker" = "https://github.com/sgl-project/sglang/issues" [tool.setuptools.packages.find] -exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] +exclude = [ + "assets*", + "benchmark*", + "docs*", + "dist*", + "playground*", + "scripts*", + "tests*", +] [tool.wheel] -exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] +exclude = [ + "assets*", + "benchmark*", + "docs*", + "dist*", + "playground*", + "scripts*", + "tests*", +] diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index a05398812..43cb7bc3f 
100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -227,8 +227,9 @@ def extend(reqs, model_runner): req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool=model_runner.token_to_kv_pool, tree_cache=None, + model_config=model_runner.model_config, ) - batch.prepare_for_extend(model_runner.model_config.vocab_size) + batch.prepare_for_extend() model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) logits_output = model_runner.forward(forward_batch) diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index d7602964d..b8f9a533d 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -229,6 +229,7 @@ register_chat_template( ), }, stop_str=("<|eot_id|>",), + image_token="<|image|>", ) ) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index c1493faad..a3c59e8d8 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -89,6 +89,8 @@ class ModelConfig: self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.vocab_size = self.hf_text_config.vocab_size + self.is_encoder_decoder = self.hf_config.model_type in ["mllama"] + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 73bbc1e2e..42b2d70d5 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -509,6 +509,19 @@ register_conv_template( ) ) +register_conv_template( + Conversation( + name="llama_3_vision", + system_message="You are a helpful language and vision assistant. 
You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.", + system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + roles=("user", "assistant"), + sep_style=SeparatorStyle.LLAMA3, + sep="", + stop_str=["<|end_of_text|>", "<|eot_id|>"], + image_token="<|image|>", + ) +) + register_conv_template( Conversation( name="llava_llama_3", diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index ae0ef6b7d..f5d573f5f 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod +from typing import Optional import torch from torch import nn +from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -19,7 +21,11 @@ class AttentionBackend(ABC): raise NotImplementedError() def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor] = None, ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() @@ -30,6 +36,7 @@ class AttentionBackend(ABC): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor] = None, ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() @@ -43,7 +50,7 @@ class AttentionBackend(ABC): q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer: nn.Module, + layer: RadixAttention, forward_batch: ForwardBatch, ): """Run forward on an attention layer.""" @@ -57,7 +64,7 @@ class AttentionBackend(ABC): q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer: nn.Module, + layer: RadixAttention, 
forward_batch: ForwardBatch, ): """Run a forward for decode.""" @@ -68,7 +75,7 @@ class AttentionBackend(ABC): q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - layer: nn.Module, + layer: RadixAttention, forward_batch: ForwardBatch, ): """Run a forward for extend.""" diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index c83fba814..73c32df8f 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner @@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend): ) def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens=None, ): + # NOTE: encoder_lens expected to be zeros or None self.forward_metadata = ( self.cuda_graph_start_loc, self.cuda_graph_attn_logits, @@ -149,14 +155,18 @@ class DoubleSparseAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, + encoder_lens=None, ): + # NOTE: encoder_lens expected to be zeros or None self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) def get_cuda_graph_seq_len_fill_value(self): return 1 - def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_extend( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: o = q.new_empty((q.shape[0], 
layer.tp_q_head_num * layer.v_head_dim)) @@ -172,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend): ) forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v, k_label + layer, forward_batch.out_cache_loc, k, v, k_label ) ( @@ -201,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend): ) return o - def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_decode( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) @@ -231,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend): ) forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v, k_label + layer, forward_batch.out_cache_loc, k, v, k_label ) # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 231300ce0..e5e7ca29c 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -11,7 +11,6 @@ from enum import Enum, auto from typing import TYPE_CHECKING import torch -import torch.nn as nn import triton import triton.language as tl @@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import is_flashinfer_available if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner if is_flashinfer_available(): @@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend): assert not ( model_runner.sliding_window_size is not None - and model_runner.has_cross_attention + and 
model_runner.model_config.is_encoder_decoder ), "Sliding window and cross attention are not supported together" if model_runner.sliding_window_size is not None: self.num_wrappers = 2 self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW - elif model_runner.has_cross_attention: + elif model_runner.model_config.is_encoder_decoder: self.num_wrappers = 2 self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION else: @@ -128,6 +128,8 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch.req_pool_indices, forward_batch.seq_lens, forward_batch.seq_lens_sum, + decode_wrappers=None, + encoder_lens=forward_batch.encoder_lens, ) self.forward_metadata = (self.decode_wrappers,) else: @@ -144,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch.req_pool_indices, forward_batch.seq_lens, prefix_lens, - use_ragged, + use_ragged=use_ragged, + encoder_lens=forward_batch.encoder_lens, ) - self.forward_metadata = ( - use_ragged, - extend_no_prefix, - ) + self.forward_metadata = (use_ragged, extend_no_prefix) def init_cuda_graph_state(self, max_bs: int): cuda_graph_kv_indices = torch.zeros( @@ -163,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend): ] def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: torch.Tensor = None, ): decode_wrappers = [] for i in range(self.num_wrappers): @@ -181,7 +185,11 @@ class FlashInferAttnBackend(AttentionBackend): seq_lens_sum = seq_lens.sum().item() self.indices_updater_decode.update( - req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers + req_pool_indices, + seq_lens, + seq_lens_sum, + decode_wrappers=decode_wrappers, + encoder_lens=encoder_lens, ) self.cuda_graph_metadata[bs] = decode_wrappers self.forward_metadata = (decode_wrappers,) @@ -192,34 +200,42 @@ class FlashInferAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: 
torch.Tensor, seq_lens_sum: int, + encoder_lens: torch.Tensor = None, ): self.indices_updater_decode.update( req_pool_indices[:bs], seq_lens[:bs], seq_lens_sum, - self.cuda_graph_metadata[bs], + decode_wrappers=self.cuda_graph_metadata[bs], + encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, ) def get_cuda_graph_seq_len_fill_value(self): return 0 - def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_extend( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): prefill_wrapper_paged = self.prefill_wrappers_paged[ self._get_wrapper_idx(layer) ] use_ragged, extend_no_prefix = self.forward_metadata + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) if not use_ragged: if k is not None: assert v is not None - forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v - ) + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + o = prefill_wrapper_paged.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=True, + causal=not layer.is_cross_attention, sm_scale=layer.scaling, window_left=layer.sliding_window_size, logits_soft_cap=layer.logit_cap, @@ -247,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend): o, _ = merge_state(o1, s1, o2, s2) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v - ) + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) return o.view(-1, layer.tp_q_head_num * layer.head_dim) - def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_decode( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)] + cache_loc = ( + forward_batch.out_cache_loc + if not 
layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) if k is not None: assert v is not None - forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v - ) + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) o = decode_wrapper.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -271,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend): return o.view(-1, layer.tp_q_head_num * layer.head_dim) - def _get_wrapper_idx(self, layer: nn.Module): + def _get_wrapper_idx(self, layer: RadixAttention): if self.num_wrappers == 1: return 0 @@ -298,6 +317,8 @@ class FlashInferIndicesUpdaterDecode: self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1) self.sliding_window_size = model_runner.sliding_window_size + self.attn_backend = attn_backend + # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr self.kv_last_page_len = attn_backend.kv_last_page_len @@ -305,20 +326,27 @@ class FlashInferIndicesUpdaterDecode: self.decode_wrappers = attn_backend.decode_wrappers # Dispatch - if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: self.update = self.update_sliding_window - elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: + elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: self.update = self.update_cross_attention else: - assert attn_backend.num_wrappers == 1 + assert self.attn_backend.num_wrappers == 1 self.update = self.update_single_wrapper + def update( + self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens + ): + # Keep the signature for type checking, will be initialized during runtime + raise NotImplementedError() + def update_single_wrapper( self, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrappers=None, + encoder_lens=None, ): decode_wrappers = 
decode_wrappers or self.decode_wrappers self.call_begin_forward( @@ -336,6 +364,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrappers=None, + encoder_lens=None, ): decode_wrappers = decode_wrappers or self.decode_wrappers @@ -363,8 +392,35 @@ class FlashInferIndicesUpdaterDecode: kv_start_idx_tmp, ) - def update_cross_attention(self): - raise NotImplementedError() + def update_cross_attention( + self, + req_pool_indices, + seq_lens, + seq_lens_sum, + decode_wrappers=None, + encoder_lens=None, + ): + decode_wrappers = decode_wrappers or self.decode_wrappers + + for wrapper_id in range(2): + if wrapper_id == 0: + # Normal attention + paged_kernel_lens = seq_lens + kv_start_idx = encoder_lens + else: + # Cross attention + paged_kernel_lens = encoder_lens + kv_start_idx = torch.zeros_like(encoder_lens) + seq_lens_sum = encoder_lens.sum().item() + + self.call_begin_forward( + decode_wrappers[wrapper_id], + req_pool_indices, + paged_kernel_lens, + seq_lens_sum, + self.kv_indptr[wrapper_id], + kv_start_idx, + ) def call_begin_forward( self, @@ -421,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill: self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1) self.sliding_window_size = model_runner.sliding_window_size + self.attn_backend = attn_backend + # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr self.kv_last_page_len = attn_backend.kv_last_page_len @@ -430,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill: self.wrappers_paged = attn_backend.prefill_wrappers_paged # Dispatch - if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: self.update = self.update_sliding_window - elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: + elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: self.update = self.update_cross_attention else: - assert attn_backend.num_wrappers == 1 + 
assert self.attn_backend.num_wrappers == 1 self.update = self.update_single_wrapper + def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens): + # Keep the signature for type checking, will be initialized during runtime + raise NotImplementedError() + def update_single_wrapper( - self, req_pool_indices, seq_lens, prefix_lens, use_ragged + self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens ): if use_ragged: paged_kernel_lens = prefix_lens @@ -460,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill: ) def update_sliding_window( - self, req_pool_indices, seq_lens, prefix_lens, use_ragged + self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens ): for wrapper_id in range(2): if wrapper_id == 0: @@ -487,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill: use_ragged, ) - def update_cross_attention(self): - raise NotImplementedError() + def update_cross_attention( + self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens + ): + for wrapper_id in range(2): + if wrapper_id == 0: + # normal attention + paged_kernel_lens = seq_lens + kv_start_idx = encoder_lens + else: + # cross attention + paged_kernel_lens = encoder_lens + kv_start_idx = torch.zeros_like(encoder_lens) + + self.call_begin_forward( + self.wrapper_ragged, + self.wrappers_paged[wrapper_id], + req_pool_indices, + paged_kernel_lens, + seq_lens, + prefix_lens, + kv_start_idx, + self.kv_indptr[wrapper_id], + self.qo_indptr[wrapper_id], + use_ragged, + ) def call_begin_forward( self, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index fb3805cfe..47b8d3cd5 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch if TYPE_CHECKING: + from 
sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner @@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend): ) def init_forward_metadata_capture_cuda_graph( - self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens=None, ): + # NOTE: encoder_lens expected to be zeros or None self.forward_metadata = ( self.cuda_graph_start_loc, self.cuda_graph_attn_logits, @@ -96,14 +102,18 @@ class TritonAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, + encoder_lens=None, ): + # NOTE: encoder_lens expected to be zeros or None self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) def get_cuda_graph_seq_len_fill_value(self): return 1 - def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_extend( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) @@ -111,7 +121,7 @@ class TritonAttnBackend(AttentionBackend): o = torch.empty_like(q) forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v + layer, forward_batch.out_cache_loc, k, v ) start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata @@ -133,7 +143,9 @@ class TritonAttnBackend(AttentionBackend): ) return o - def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch): + def forward_decode( + self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. 
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) @@ -147,7 +159,7 @@ class TritonAttnBackend(AttentionBackend): start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata forward_batch.token_to_kv_pool.set_kv_buffer( - layer.layer_id, forward_batch.out_cache_loc, k, v + layer, forward_batch.out_cache_loc, k, v ) self.decode_attention_fwd( diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index b958ab89b..08ad15023 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -33,20 +33,9 @@ def init_global_processor(server_args: ServerArgs): class BaseImageProcessor(ABC): - @abstractmethod - async def process_images_async(self, image_data, **kwargs): - pass - - -class DummyImageProcessor(BaseImageProcessor): - async def process_images_async(self, *args, **kwargs): - return None - - -class LlavaImageProcessor(BaseImageProcessor): - def __init__(self, hf_config, server_args, _image_processor): + def __init__(self, hf_config, server_args, _processor): self.hf_config = hf_config - self._image_processor = _image_processor + self._processor = _processor self.executor = concurrent.futures.ProcessPoolExecutor( initializer=init_global_processor, mp_context=mp.get_context("fork"), @@ -54,6 +43,23 @@ class LlavaImageProcessor(BaseImageProcessor): max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()), ) + @abstractmethod + async def process_images_async(self, image_data, input_text, **kwargs): + pass + + +class DummyImageProcessor(BaseImageProcessor): + def __init__(self): + pass + + async def process_images_async(self, *args, **kwargs): + return None + + +class LlavaImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + @staticmethod def _process_single_image_task( image_data: Union[str, bytes], @@ -119,7 +125,7 @@ class 
LlavaImageProcessor(BaseImageProcessor): ) async def process_images_async( - self, image_data: List[Union[str, bytes]], request_obj + self, image_data: List[Union[str, bytes]], input_text, request_obj ): if not image_data: return None @@ -177,6 +183,54 @@ class LlavaImageProcessor(BaseImageProcessor): } +class MllamaImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + + @staticmethod + def _process_single_image_task(images, input_text): + # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask' + return global_processor(images, input_text, return_tensors="pt") + + async def _process_single_image(self, images, input_text): + if self.executor is not None: + loop = asyncio.get_event_loop() + image_inputs = await loop.run_in_executor( + self.executor, + MllamaImageProcessor._process_single_image_task, + images, + input_text, + ) + else: + image_inputs = self._processor(images, input_text, return_tensors="pt") + + return image_inputs + + async def process_images_async( + self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs + ): + if not image_data: + return None + + if isinstance(input_text, list): + assert len(input_text) and isinstance(input_text[0], int) + input_text = self._processor.tokenizer.decode(input_text) + + if not isinstance(image_data, list): + image_data = [image_data] + + if len(image_data) > 0: + images = [load_image(image)[0] for image in image_data] + else: + images = load_image(image_data[0])[0] + + image_inputs = await self._process_single_image(images, input_text) + image_inputs["image_hashes"] = [hash(str(image_data))] + image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] + + return image_inputs + + class Qwen2VLImageProcessor(BaseImageProcessor): def __init__(self, hf_config, server_args, _image_processor): self.hf_config = hf_config @@ -237,7 +291,7 @@ class 
Qwen2VLImageProcessor(BaseImageProcessor): return self._process_single_image_task(image_data) async def process_images_async( - self, image_data: List[Union[str, bytes]], request_obj + self, image_data: List[Union[str, bytes]], input_text, request_obj ): if not image_data: return None @@ -292,12 +346,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor): def get_image_processor( - hf_config, server_args: ServerArgs, _image_processor + hf_config, server_args: ServerArgs, processor ) -> BaseImageProcessor: - if "Qwen2VLForConditionalGeneration" in hf_config.architectures: - return Qwen2VLImageProcessor(hf_config, server_args, _image_processor) + if "MllamaForConditionalGeneration" in hf_config.architectures: + return MllamaImageProcessor(hf_config, server_args, processor) + elif "Qwen2VLForConditionalGeneration" in hf_config.architectures: + return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor) else: - return LlavaImageProcessor(hf_config, server_args, _image_processor) + return LlavaImageProcessor(hf_config, server_args, processor.image_processor) def get_dummy_image_processor(): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b0ab2dfe5..bcf3103ad 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -36,6 +36,7 @@ from typing import List, Optional, Tuple, Union import torch from sglang.global_config import global_config +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained import RegexGuide from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache @@ -121,11 +122,12 @@ class ImageInputs: """The image related inputs.""" pixel_values: torch.Tensor - image_hash: int + image_hashes: Optional[list] = None image_sizes: Optional[list] = None image_offsets: Optional[list] = None pad_values: Optional[list] = None modalities: 
Optional[list] = None + num_image_tokens: Optional[int] = None image_embeds: Optional[List[torch.Tensor]] = None aspect_ratio_ids: Optional[List[torch.Tensor]] = None @@ -138,19 +140,27 @@ class ImageInputs: # Use image hash as fake token_ids, which is then used for prefix matching ret = ImageInputs( pixel_values=obj["pixel_values"], - image_hash=hash(tuple(obj["image_hashes"])), - image_grid_thws=obj.get("image_grid_thws"), + image_hashes=hash(tuple(obj["image_hashes"])), ) - image_hash = ret.image_hash + image_hash = ret.image_hashes ret.pad_values = [ (image_hash) % vocab_size, (image_hash >> 16) % vocab_size, (image_hash >> 32) % vocab_size, (image_hash >> 64) % vocab_size, ] - ret.image_sizes = obj["image_sizes"] - # Only when pixel values is not None we have modalities - ret.modalities = obj["modalities"] or ["image"] + + optional_args = [ + "image_sizes", + "modalities", + "aspect_ratio_ids", + "aspect_ratio_mask", + "image_grid_thws", + ] + for arg in optional_args: + if arg in obj: + setattr(ret, arg, obj[arg]) + return ret @@ -416,6 +426,10 @@ class ScheduleBatch: req_to_token_pool: ReqToTokenPool = None token_to_kv_pool: BaseTokenToKVPool = None tree_cache: BasePrefixCache = None + + # For utility + model_config: ModelConfig = None + forward_mode: ForwardMode = None sampling_info: SamplingBatchInfo = None @@ -440,6 +454,12 @@ class ScheduleBatch: extend_num_tokens: int = None decoding_reqs: List[Req] = None + # For encoder-decoder + encoder_cached: Optional[List[bool]] = None + encoder_lens: Optional[torch.Tensor] = None + encoder_lens_cpu: Optional[List[int]] = None + encoder_out_cache_loc: Optional[torch.Tensor] = None + # Stream has_stream: bool = False @@ -450,12 +470,20 @@ class ScheduleBatch: device: str = "cuda" @classmethod - def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): + def init_new( + cls, + reqs, + req_to_token_pool, + token_to_kv_pool, + tree_cache, + model_config, + ): return cls( reqs=reqs, 
req_to_token_pool=req_to_token_pool, token_to_kv_pool=token_to_kv_pool, tree_cache=tree_cache, + model_config=model_config, return_logprob=any(req.return_logprob for req in reqs), has_stream=any(req.stream for req in reqs), has_regex=any(req.regex_fsm for req in reqs), @@ -493,7 +521,78 @@ class ScheduleBatch: return out_cache_loc - def prepare_for_extend(self, vocab_size: int): + def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): + self.encoder_lens_cpu = [] + self.encoder_cached = [] + + for req in self.reqs: + im = req.image_inputs + if im is None or im.num_image_tokens is None: + # No image input + self.encoder_lens_cpu.append(0) + self.encoder_cached.append(True) + else: + self.encoder_lens_cpu.append(im.num_image_tokens) + self.encoder_cached.append( + self.forward_mode.is_decode() + or len(req.prefix_indices) >= im.num_image_tokens + ) + + self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to( + self.device, non_blocking=True + ) + + # Strip encoder infos + pt = 0 + decoder_out_cache_loc = [] + encoder_out_cache_loc = [] + for i, req in enumerate(self.reqs): + encoder_len = self.encoder_lens_cpu[i] + seq_lens[i] -= encoder_len + + if len(req.prefix_indices) < encoder_len: + # NOTE: the encoder part should considered as a whole + assert len(req.prefix_indices) == 0 + input_ids[i] = input_ids[i][encoder_len:] + encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len]) + decoder_out_cache_loc.append( + self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len] + ) + self.extend_lens[i] -= encoder_len + self.extend_num_tokens -= encoder_len + else: + decoder_out_cache_loc.append( + self.out_cache_loc[pt : pt + req.extend_input_len] + ) + self.prefix_lens[i] -= encoder_len + + pt += req.extend_input_len + + # Reassign + self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( + self.device, non_blocking=True + ) + self.seq_lens = torch.tensor(seq_lens, 
dtype=torch.int32).to( + self.device, non_blocking=True + ) + + if not decoder_out_cache_loc: + self.out_cache_loc = torch.empty(0, dtype=torch.int32).to( + self.device, non_blocking=True + ) + else: + self.out_cache_loc = torch.cat(decoder_out_cache_loc) + + if not encoder_out_cache_loc: + self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to( + self.device, non_blocking=True + ) + else: + self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc) + + assert len(self.out_cache_loc) == self.extend_num_tokens + + def prepare_for_extend(self): self.forward_mode = ForwardMode.EXTEND bs = len(self.reqs) @@ -561,8 +660,13 @@ class ScheduleBatch: self.extend_lens = [r.extend_input_len for r in reqs] self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] + if self.model_config.is_encoder_decoder: + self.prepare_encoder_info_extend(input_ids, seq_lens) + self.sampling_info = SamplingBatchInfo.from_schedule_batch( - self, vocab_size, global_server_args_dict["disable_penalizer"] + self, + self.model_config.vocab_size, + global_server_args_dict["disable_penalizer"], ) def mix_with_running(self, running_batch: "ScheduleBatch"): @@ -752,6 +856,10 @@ class ScheduleBatch: return jump_forward_reqs + def prepare_encoder_info_decode(self): + # Reset the encoder cached status + self.encoder_cached = [True] * len(self.reqs) + def prepare_for_decode(self, enable_overlap: bool = False): self.forward_mode = ForwardMode.DECODE @@ -766,16 +874,22 @@ class ScheduleBatch: bs = len(self.reqs) self.out_cache_loc = self.alloc_token_slots(bs) + if self.model_config.is_encoder_decoder: + locs = self.encoder_lens + self.seq_lens + self.prepare_encoder_info_decode() + else: + locs = self.seq_lens + if enable_overlap: # Do not use in-place operations in the overlap mode self.req_to_token_pool.write( - (self.req_pool_indices, self.seq_lens), self.out_cache_loc + (self.req_pool_indices, locs), self.out_cache_loc ) self.seq_lens = self.seq_lens + 1 else: # A faster 
in-place version self.req_to_token_pool.write( - (self.req_pool_indices, self.seq_lens), self.out_cache_loc + (self.req_pool_indices, locs), self.out_cache_loc ) self.seq_lens.add_(1) self.seq_lens_sum += bs @@ -802,6 +916,10 @@ class ScheduleBatch: # No need to filter return + if self.model_config.is_encoder_decoder: + self.encoder_lens = self.encoder_lens[keep_indices] + self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices] + self.reqs = [self.reqs[i] for i in keep_indices] new_indices = torch.tensor(keep_indices, dtype=torch.int32).to( self.device, non_blocking=True @@ -828,6 +946,11 @@ class ScheduleBatch: # needs to be called with pre-merged Batch.reqs. self.sampling_info.merge_batch(other.sampling_info) + # Encoder-decoder infos + if self.model_config.is_encoder_decoder: + self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens]) + self.encoder_lens_cpu.extend(other.encoder_lens_cpu) + self.req_pool_indices = torch.concat( [self.req_pool_indices, other.req_pool_indices] ) @@ -850,14 +973,11 @@ class ScheduleBatch: def get_model_worker_batch(self): if self.forward_mode.is_decode(): - extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = ( - image_inputs - ) = None + extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None else: extend_seq_lens = self.extend_lens extend_prefix_lens = self.prefix_lens extend_logprob_start_lens = self.extend_logprob_start_lens - image_inputs = [r.image_inputs for r in self.reqs] if self.has_regex: self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs] @@ -887,7 +1007,11 @@ class ScheduleBatch: extend_seq_lens=extend_seq_lens, extend_prefix_lens=extend_prefix_lens, extend_logprob_start_lens=extend_logprob_start_lens, - image_inputs=image_inputs, + image_inputs=[r.image_inputs for r in self.reqs], + encoder_cached=self.encoder_cached, + encoder_lens=self.encoder_lens, + encoder_lens_cpu=self.encoder_lens_cpu, + 
encoder_out_cache_loc=self.encoder_out_cache_loc, lora_paths=[req.lora_path for req in self.reqs], sampling_info=self.sampling_info, mrope_positions_delta=mrope_positions_delta, @@ -897,6 +1021,7 @@ class ScheduleBatch: # Only contain fields that will be used by process_batch_result return ScheduleBatch( reqs=self.reqs, + model_config=self.model_config, forward_mode=self.forward_mode, out_cache_loc=self.out_cache_loc, return_logprob=self.return_logprob, @@ -944,6 +1069,12 @@ class ModelWorkerBatch: # For multimodal image_inputs: Optional[List[ImageInputs]] + # For encoder-decoder + encoder_cached: Optional[List[bool]] + encoder_lens: Optional[torch.Tensor] + encoder_lens_cpu: Optional[List[int]] + encoder_out_cache_loc: Optional[torch.Tensor] + # For LoRA lora_paths: Optional[List[str]] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1b68bacd9..b2f217c85 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -662,8 +662,9 @@ class Scheduler: self.req_to_token_pool, self.token_to_kv_pool, self.tree_cache, + self.model_config, ) - new_batch.prepare_for_extend(self.model_config.vocab_size) + new_batch.prepare_for_extend() # Mixed-style chunked prefill if self.is_mixed_chunk and self.running_batch is not None: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 2bc7ff04b..fc9e23519 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -122,7 +122,7 @@ class TokenizerManager: # We want to parallelize the image pre-processing so we create an executor for it self.image_processor = get_image_processor( - self.hf_config, server_args, self.processor.image_processor + self.hf_config, server_args, self.processor ) else: self.tokenizer = get_tokenizer( @@ -191,8 +191,10 @@ class TokenizerManager: sampling_params = 
self._get_sampling_params(obj.sampling_params) if self.is_generation: image_inputs = await self.image_processor.process_images_async( - obj.image_data, obj + obj.image_data, input_text or input_ids, obj ) + if image_inputs and "input_ids" in image_inputs: + input_ids = image_inputs["input_ids"] return_logprob = obj.return_logprob logprob_start_len = obj.logprob_start_len top_logprobs_num = obj.top_logprobs_num @@ -217,8 +219,10 @@ class TokenizerManager: sampling_params = self._get_sampling_params(obj.sampling_params[index]) if self.is_generation: image_inputs = await self.image_processor.process_images_async( - obj.image_data[index], obj + obj.image_data[index], input_text or input_ids, obj ) + if image_inputs and "input_ids" in image_inputs: + input_ids = image_inputs["input_ids"] return_logprob = obj.return_logprob[index] logprob_start_len = obj.logprob_start_len[index] top_logprobs_num = obj.top_logprobs_num[index] @@ -263,8 +267,10 @@ class TokenizerManager: sampling_params = SamplingParams(**obj.sampling_params[0]) sampling_params.max_new_tokens = 0 image_inputs = await self.image_processor.process_images_async( - obj.image_data[0], obj + obj.image_data[0], input_text or input_ids, obj ) + if image_inputs and "input_ids" in image_inputs: + input_ids = image_inputs["input_ids"] return_logprob = obj.return_logprob[0] logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index bd42dfc72..4277862a7 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -26,6 +26,8 @@ from typing import List, Tuple, Union import torch +from sglang.srt.layers.radix_attention import RadixAttention + logger = logging.getLogger(__name__) @@ -41,13 +43,17 @@ class ReqToTokenPool: ) self.free_slots = list(range(size)) self.write_records = [] + self.use_records = use_records - if use_records: - # 
records all write operations + if self.use_records: self.write = self.write_with_records else: self.write = self.write_without_records + def write(self, indices, values): + # Keep the signature for type checking, will be initialized during runtime + raise NotImplementedError() + def available_size(self): return len(self.free_slots) @@ -154,7 +160,7 @@ class BaseTokenToKVPool: def set_kv_buffer( self, - layer_id: int, + layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, @@ -209,11 +215,12 @@ class MHATokenToKVPool(BaseTokenToKVPool): def set_kv_buffer( self, - layer_id: int, + layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, ): + layer_id = layer.layer_id if cache_k.dtype != self.dtype: cache_k = cache_k.to(self.dtype) if cache_v.dtype != self.dtype: @@ -265,11 +272,12 @@ class MLATokenToKVPool(BaseTokenToKVPool): def set_kv_buffer( self, - layer_id: int, + layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, ): + layer_id = layer.layer_id if cache_k.dtype != self.dtype: cache_k = cache_k.to(self.dtype) if self.store_dtype != self.dtype: @@ -324,13 +332,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool): def set_kv_buffer( self, - layer_id: int, + layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, cache_label: torch.Tensor, ): # NOTE(Andy): ignore the dtype check + layer_id = layer.layer_id self.k_buffer[layer_id][loc] = cache_k self.v_buffer[layer_id][loc] = cache_v self.label_buffer[layer_id][loc] = cache_label diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 37e3c8429..ffa77ec4c 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -105,6 +105,7 @@ class CudaGraphRunner: self.graph_memory_pool = None self.use_torch_compile = 
model_runner.server_args.enable_torch_compile self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder # Batch sizes to capture if self.model_runner.server_args.disable_cuda_graph_padding: @@ -132,6 +133,9 @@ class CudaGraphRunner: self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) + # FIXME(lsyin): leave it here for now, I don't know whether it is necessary + self.encoder_len_fill_value = 0 + if self.use_torch_compile: set_torch_compile_config() @@ -144,9 +148,18 @@ class CudaGraphRunner: ) self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32) + if self.is_encoder_decoder: + # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch + self.encoder_lens = torch.full( + (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32 + ) + else: + self.encoder_lens = None + # Capture try: - self.capture() + with self.model_capture_mode(): + self.capture() except RuntimeError as e: raise Exception( f"Capture cuda graph failed: {e}\n" @@ -157,11 +170,32 @@ class CudaGraphRunner: "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" ) - def can_run(self, batch_size: int): - if self.disable_padding: - return batch_size in self.graphs - else: - return batch_size <= self.max_bs + @contextmanager + def model_capture_mode(self): + if hasattr(self.model_runner.model, "capture_mode"): + self.model_runner.model.capture_mode = True + + yield + + if hasattr(self.model_runner.model, "capture_mode"): + self.model_runner.model.capture_mode = False + + def can_run(self, forward_batch: ForwardBatch): + is_bs_supported = ( + forward_batch.batch_size in self.graphs + if self.disable_padding + else forward_batch.batch_size <= self.max_bs + ) + + # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0) + # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph 
+ # because the full_text_row_masked_out_mask tensor will always be ones + is_encoder_lens_supported = ( + torch.all(forward_batch.encoder_lens > 0) + if self.is_encoder_decoder + else True + ) + return is_bs_supported and is_encoder_lens_supported def capture(self): with graph_capture() as graph_capture_context: @@ -188,11 +222,19 @@ class CudaGraphRunner: req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] out_cache_loc = self.out_cache_loc[:bs] + if self.is_encoder_decoder: + encoder_lens = self.encoder_lens[:bs] + else: + encoder_lens = None + seq_lens_sum = seq_lens.sum().item() # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( - bs, req_pool_indices, seq_lens + bs, + req_pool_indices, + seq_lens, + encoder_lens, ) # Run and capture @@ -208,6 +250,7 @@ class CudaGraphRunner: attn_backend=self.model_runner.attn_backend, out_cache_loc=out_cache_loc, seq_lens_sum=seq_lens_sum, + encoder_lens=encoder_lens, return_logprob=False, top_logprobs_nums=[0] * bs, positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), @@ -251,6 +294,8 @@ class CudaGraphRunner: self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc) + if self.is_encoder_decoder: + self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( @@ -258,6 +303,7 @@ class CudaGraphRunner: self.req_pool_indices, self.seq_lens, forward_batch.seq_lens_sum, + self.encoder_lens, ) # Replay diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index f4e117b76..f3065d7a2 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -108,6 +108,12 @@ class ForwardBatch: # For multimodal 
image_inputs: Optional[List[ImageInputs]] = None + # Encoder-decoder + encoder_cached: Optional[List[bool]] = None + encoder_lens: Optional[torch.Tensor] = None + encoder_lens_cpu: Optional[List[int]] = None + encoder_out_cache_loc: Optional[torch.Tensor] = None + # For LoRA lora_paths: Optional[List[str]] = None @@ -194,6 +200,11 @@ class ForwardBatch: req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, out_cache_loc=batch.out_cache_loc, + image_inputs=batch.image_inputs, + encoder_cached=batch.encoder_cached, + encoder_lens=batch.encoder_lens, + encoder_lens_cpu=batch.encoder_lens_cpu, + encoder_out_cache_loc=batch.encoder_out_cache_loc, seq_lens_sum=batch.seq_lens_sum, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, @@ -212,11 +223,11 @@ class ForwardBatch: ], axis=0, ) - ret.image_inputs = batch.image_inputs ret.extend_num_tokens = batch.extend_num_tokens ret.extend_seq_lens = torch.tensor( batch.extend_seq_lens, dtype=torch.int32 ).to(device, non_blocking=True) + ret.extend_prefix_lens = torch.tensor( batch.extend_prefix_lens, dtype=torch.int32 ).to(device, non_blocking=True) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 291528e07..e2a2504cb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -270,7 +270,6 @@ class ModelRunner: if hasattr(self.model, "get_attention_sliding_window_size") else None ) - self.has_cross_attention = getattr(self.model, "has_cross_attention", False) self.is_generation = is_generation_model( self.model_config.hf_config.architectures, self.server_args.is_embedding ) @@ -510,7 +509,7 @@ class ModelRunner: "Window attention is not supported in the triton attention backend. " "Please use `--attention-backend flashinfer`." 
) - assert not self.has_cross_attention, ( + assert not self.model_config.is_encoder_decoder, ( "Cross attention is not supported in the triton attention backend. " "Please use `--attention-backend flashinfer`." ) @@ -558,9 +557,7 @@ class ModelRunner: self.cuda_graph_runner = CudaGraphRunner(self) def forward_decode(self, forward_batch: ForwardBatch): - if self.cuda_graph_runner and self.cuda_graph_runner.can_run( - forward_batch.batch_size - ): + if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch): return self.cuda_graph_runner.replay(forward_batch) forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64) diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py new file mode 100644 index 000000000..7db6f0e1f --- /dev/null +++ b/python/sglang/srt/models/mllama.py @@ -0,0 +1,1004 @@ +# Adapted from: +# https://github.com/vllm-project/vllm/blob/7193774b1ff8603ad5bf4598e5efba0d9a39b436/vllm/model_executor/models/mllama.py +"""PyTorch Mllama model.""" +import math +from typing import Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers.models.mllama.configuration_mllama as config_mllama +import vllm.distributed.parallel_state as ps +from torch import nn +from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast +from transformers.models.mllama.modeling_mllama import ( + _prepare_aspect_ratio_attention_mask, +) +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + 
RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.schedule_batch import ImageInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP + + +class ColumnParallelConv2dPatch(torch.nn.Module): + """Conv2D Patching layer with model parallelism. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. + Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + bias: bool = False, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) + self._linear = ColumnParallelLinear( + in_channels * kernel_size[0] * kernel_size[1], + out_channels, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._unfold(x) + x = x.permute(0, 2, 1) + x, _ = self._linear(x) + return x + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig, is_gated: bool = True): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size + ) + if is_gated: + self.gate = nn.Parameter(torch.zeros(1)) + + def forward( + self, 
hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter(torch.zeros(1)) + + # position embedding + position_embedding = torch.randn(self.num_patches, self.hidden_size) + self.embedding = nn.Parameter(self.scale * position_embedding) + + # tile position embedding + self.tile_embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.num_patches * self.hidden_size, + ) + + def forward( + self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor + ) -> torch.Tensor: + # position embeddings + gated_position_embedding = (1 - self.gate.tanh()) * self.embedding + hidden_state = hidden_state + gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size + ) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size + ) + gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +class MllamaVisionSdpaAttention(nn.Module): + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + + 
model_parallel_size = get_tensor_model_parallel_world_size() + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + self.num_local_heads = self.num_heads // model_parallel_size + self.q_size = self.num_local_heads * self.head_dim + self.kv_size = self.num_local_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=False, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + input_is_parallel=True, + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_state) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view( + q.shape[0], q.shape[1], self.num_local_heads, self.head_dim + ).transpose(1, 2) + k = k.view( + k.shape[0], k.shape[1], self.num_local_heads, self.head_dim + ).transpose(1, 2) + v = v.view( + v.shape[0], v.shape[1], self.num_local_heads, self.head_dim + ).transpose(1, 2) + + # TODO: remove padding in image encoder + attn_output = F.scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask, dropout_p=0.0 + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape( + attn_output.shape[0], attn_output.shape[1], -1 + ) + output, _ = self.o_proj(attn_output) + return output + + +class MllamaVisionMLP(nn.Module): + def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class MllamaVisionEncoderLayer(nn.Module): + def __init__( + self, config: config_mllama.MllamaVisionConfig, is_gated: bool = False + ): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention(config) + self.mlp = MllamaVisionMLP(config) + + self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm( + self.hidden_size, eps=config.norm_eps + ) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) + self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + def __init__( + self, + config: config_mllama.MllamaVisionConfig, + num_layers=32, + is_gated=False, + output_hidden_states=None, + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [MllamaVisionEncoderLayer(config, is_gated) for _ in 
range(num_layers)] + ) + self.output_hidden_states = output_hidden_states or [] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutput]: + encoder_states = () + + for i, encoder_layer in enumerate(self.layers): + if i in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + + if len(self.layers) - 1 in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return hidden_states, encoder_states + + +class MllamaVisionModel(nn.Module): + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.in_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = ColumnParallelConv2dPatch( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + config, is_gated=True + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( + config, is_gated=True + ) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size) + self.layernorm_post = nn.LayerNorm(self.hidden_size) + + # encoders + self.transformer = MllamaVisionEncoder( + config, + config.num_hidden_layers, + is_gated=False, + output_hidden_states=config.intermediate_layers_indices, + ) + 
self.global_transformer = MllamaVisionEncoder( + config, config.num_global_layers, is_gated=True + ) + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + ) -> torch.Tensor: + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( + pixel_values.shape + ) + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, height, width + ) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1 + ) + + # patch embedding + patch_embeds = self.patch_embedding( + pixel_values.to(self.layernorm_pre.weight.dtype) + ) + hidden_state = patch_embeds + hidden_state = ps.get_tp_group().all_gather(hidden_state) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, -1, dim + ) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim + ) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches, dim + ) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, + 0, + 0, + 
num_padding_patches, + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1 + ) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.layernorm_pre.weight.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state, intermediate_hidden_states = output[0], output[1] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids + ) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), + dim, + ) + hidden_state = self.global_transformer( + hidden_state, attention_mask=attention_mask + )[0] + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim, + ) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, dim + ) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + -1, + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + 
intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + return hidden_state + + +class MllamaTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MllamaTextCrossAttention(nn.Module): + def __init__( + self, + config: Optional[config_mllama.MllamaTextConfig] = None, + layer_id: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.model_parallel_size = get_tensor_model_parallel_world_size() + self.num_heads = self.config.num_attention_heads + self.num_local_heads = self.num_heads // self.model_parallel_size + self.num_key_value_heads = self.config.num_key_value_heads + self.num_local_key_value_heads = ( + self.num_key_value_heads // self.model_parallel_size + ) + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.layer_id = layer_id + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.q_local_size = self.num_local_heads * self.head_dim + self.kv_local_size = self.num_local_key_value_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = 
RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + ) + # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, + # use huggingface's instead + self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.scaling = self.head_dim**-0.5 + + self.attn = RadixAttention( + self.num_local_heads, + self.head_dim, + self.scaling, + self.num_local_key_value_heads, + layer_id=layer_id, + is_cross_attention=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cross_attention_states: Optional[torch.Tensor], + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv_dec, _ = self.qkv_proj(hidden_states) + q, _, _ = qkv_dec.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1 + ) + if cross_attention_states is None: + k = None + v = None + else: + qkv_enc, _ = self.qkv_proj(cross_attention_states) + _, k, v = qkv_enc.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1 + ) + k = k.view(-1, self.num_local_key_value_heads, self.head_dim) + v = v.view(-1, self.num_local_key_value_heads, self.head_dim) + k = self.k_norm(k) + q = q.view(-1, self.num_local_heads, self.head_dim) + q = self.q_norm(q) + + output = self.attn(q, k, v, forward_batch) + out, _ = self.o_proj(output) + return out + + +class MllamaCrossAttentionDecoderLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention + and feedforward.""" + + def __init__( + self, + config: config_mllama.MllamaTextConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig], + ) -> None: + super().__init__() + self.layer_id = layer_id + self.cross_attn = MllamaTextCrossAttention( + config=config, + layer_id=layer_id, + quant_config=quant_config, + ) + + self.input_layernorm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) + self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) + + self.mlp = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + full_text_row_masked_out_mask: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + forward_batch=forward_batch, + ) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + return hidden_states + + +class MllamaTextModel(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "model" + + def __init__( + self, + config: config_mllama.MllamaTextConfig, + quant_config: Optional[QuantizationConfig], + cache_config=None, + ): + super().__init__() + self.padding_id = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size + 8, config.hidden_size + ) + self.cross_attention_layers = config.cross_attention_layers + + layers = [] + for layer_id in range(config.num_hidden_layers): + if layer_id in self.cross_attention_layers: + 
class MllamaForCausalLM(nn.Module):
    """Text-only (decoder) half of Mllama: the interleaved self/cross-attention
    transformer plus its LM head. Logits are computed by the caller."""

    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer",
        "MllamaSelfAttentionDecoderLayer",
    ]

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        quant_config: Optional[QuantizationConfig],
        cache_config=None,
    ):
        super().__init__()
        self.vocab_size = config.vocab_size
        # BUG FIX: MllamaTextModel's signature is
        # (config, quant_config, cache_config=None); the original call passed
        # (config, cache_config, quant_config) positionally, silently swapping
        # the two arguments and dropping the quantization config. Pass by
        # keyword so the mistake cannot recur.
        self.model = MllamaTextModel(
            config, quant_config=quant_config, cache_config=cache_config
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
        forward_batch: ForwardBatch,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Delegate to the underlying text model and return its final hidden
        states (the LM head is applied by the logits processor upstream)."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            forward_batch=forward_batch,
            skip_cross_attention=skip_cross_attention,
        )
self.capture_mode = False + + def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + pixel_values = image_inputs.pixel_values + pad_values = image_inputs.pad_values + + num_concurrent_media, num_tiles = pixel_values.shape[1:3] + num_patches = self.vision_model.num_patches + image_len = num_concurrent_media * num_tiles * num_patches + image_inputs.num_image_tokens = image_len + + pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values)) + + return pad_ids[:image_len] + input_ids + + def _batch_image_inputs(self, forward_batch: ForwardBatch): + if forward_batch.forward_mode.is_decode() or all(forward_batch.encoder_cached): + return None, None, None, None + + # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res) + max_num_images = max_num_tiles = bs = 0 + for i, im in enumerate(forward_batch.image_inputs): + if not forward_batch.encoder_cached[i] and im is not None: + max_num_images = max(max_num_images, im.pixel_values.shape[1]) + max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2]) + bs += 1 + + if max_num_images * max_num_tiles * bs == 0: + return None, None, None, None + + with forward_batch.out_cache_loc.device: + batched_images = torch.zeros( + bs, + max_num_images, + max_num_tiles, + 3, + self.image_size, + self.image_size, + dtype=torch.float32, + ) + batched_ar_ids = torch.ones( + bs, max_num_images, dtype=torch.int64, device="cuda" + ) + batched_ar_mask = torch.zeros( + bs, max_num_images, max_num_tiles, dtype=torch.int64 + ) + i = 0 + encoder_lens_need = [] + for k, im in enumerate(forward_batch.image_inputs): + if forward_batch.encoder_cached[k] or im is None: + continue + + encoder_lens_need.append(forward_batch.encoder_lens[k]) + for j in range(im.pixel_values.shape[1]): + img = im.pixel_values[0, j] + num_tiles = img.shape[0] + batched_images[i, j, :num_tiles] = img + batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j] + batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j] + i += 1 
+ + return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need + + def flat_encoder_result( + self, cross_attention_states: torch.Tensor, encoder_lens_need: List[int] + ): + # NOTE: not all encoders need computation, some are cached + head_dim = cross_attention_states.shape[-1] + total_encoder_len = sum(encoder_lens_need) + cross_attention_states_flat = torch.zeros( + total_encoder_len, + head_dim, + device=cross_attention_states.device, + dtype=cross_attention_states.dtype, + ) + + i = start_pos = 0 + for encoder_len in encoder_lens_need: + if encoder_len == 0: + continue + end_pos = start_pos + encoder_len + cross_attention_states_flat[start_pos:end_pos] = cross_attention_states[i][ + :encoder_len + ] + i += 1 + start_pos += encoder_len + + return cross_attention_states_flat + + def get_full_text_row_masked_out_mask(self, forward_batch: ForwardBatch): + if forward_batch.forward_mode.is_decode(): + full_text_row_masked_out_mask = forward_batch.encoder_lens != 0 + else: + full_text_row_masked_out_mask = torch.ones( + forward_batch.extend_seq_lens.sum(), dtype=torch.bool + ) + start_pos = 0 + + for seq_len, encoder_len in zip( + forward_batch.seq_lens.tolist(), forward_batch.encoder_lens_cpu + ): + if encoder_len == 0: + full_text_row_masked_out_mask[start_pos : start_pos + seq_len] = ( + False + ) + start_pos += encoder_len + + full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( + forward_batch.seq_lens.device + ) + + return full_text_row_masked_out_mask.reshape(-1, 1) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> Union[Tuple, CausalLMOutputWithPast]: + batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = ( + self._batch_image_inputs(forward_batch) + ) + + # TODO: support multi-image by this mask + cross_attention_mask = None + cross_attention_states = None + + if self.capture_mode: + # NOTE: when doing cuda graph capture, we do not want to skip cross 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Fused projections (qkv_proj, gate_up_proj) are assembled from their
        per-shard checkpoint names via ``stacked_params_mapping``; everything
        else is loaded by exact name with the parameter's own ``weight_loader``
        (falling back to ``default_weight_loader``).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        updated_params = set()
        for name, loaded_weight in weights:
            if "patch_embedding.weight" in name:
                # The checkpoint stores a conv-style patch embedding; this
                # model implements it as a linear layer, so rename and flatten
                # the kernel to (out_features, in_features).
                name = name.replace(
                    "patch_embedding.weight", "patch_embedding._linear.weight"
                )
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Map the shard's checkpoint name onto the fused parameter and
                # let its weight_loader place the shard at the right offset.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                updated_params.add(name)
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Exact-name load. NOTE(review): ``pop`` removes the entry, so
                # a duplicate name in the checkpoint would raise KeyError —
                # presumably intentional as a double-load guard; confirm.
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
class TestMllamaServer(TestOpenAIVisionServer):
    """Run the shared OpenAI vision test suite against
    meta-llama/Llama-3.2-11B-Vision-Instruct served with the
    ``llama_3_vision`` chat template."""

    @classmethod
    def setUpClass(cls):
        # Launch the server once for the whole class; the shared suite's
        # requests go through the OpenAI-compatible /v1 endpoint.
        cls.model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--chat-template",
                "llama_3_vision",
            ],
        )
        cls.base_url += "/v1"

    def test_video_chat_completion(self):
        # Intentionally overridden to a no-op: skips the inherited video test
        # for this model (presumably unsupported by mllama — confirm).
        pass