diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index f1f45e27a..41f3110cd 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -1,15 +1,17 @@ from __future__ import annotations +import dataclasses +import functools import math -from functools import lru_cache, wraps -from typing import Optional, Tuple +from functools import lru_cache +from typing import Any, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from sglang.srt.utils import is_cuda +from sglang.srt.utils import is_cuda, print_info_once _is_cuda = is_cuda() @@ -29,29 +31,42 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.utils import add_prefix, logger +from sglang.srt.utils import add_prefix ROTARY_EMBED_CLASSES = { "normal": apply_rotary_pos_emb, } -def execute_once(func): - has_run = None +@dataclasses.dataclass +class SingletonCache: + data: Any = None - @wraps(func) - def wrapper(*args, **kwargs): - nonlocal has_run - if not has_run: - func(*args, **kwargs) - has_run = True + def set_data(self, value: Any) -> None: + self.data = value - return wrapper + def get_data(self) -> Optional[Any]: + return self.data + + def empty(self) -> bool: + return self.get_data() is None -@execute_once -def info_once(message: str): - logger.info(message) +# TODO: requires real seqlens from images +@functools.lru_cache(maxsize=128) +def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor: + """ + Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device. + Caches the result based on these parameters. 
+ """ + cu_seqlens = torch.arange( + 0, + (batch_size + 1) * seqlen, + step=seqlen, + dtype=torch.int32, + device=device, + ) + return cu_seqlens class VisionSdpaAttention(nn.Module): @@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module): q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens: Optional[torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]], + bsz: int, + seq_len: int, **kwargs, ) -> torch.Tensor: r""" @@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module): Returns: [b * s, h, head_size] """ - cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda() + if cu_seqlens is None: + cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device) + elif isinstance(cu_seqlens, SingletonCache): + if cu_seqlens.empty(): + cu_seqlens.set_data( + _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device) + ) + cu_seqlens = cu_seqlens.get_data() + + cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device) seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() output = flash_attn_varlen_func( @@ -346,11 +371,11 @@ class VisionAttention(nn.Module): if global_server_args_dict["mm_attention_backend"] is None: if qkv_backend is None: qkv_backend = "sdpa" - info_once(f"Multimodal attention backend not set. Use {qkv_backend}.") + print_info_once(f"Multimodal attention backend not set. 
Use {qkv_backend}.") else: qkv_backend = global_server_args_dict["mm_attention_backend"] - info_once(f"Using {qkv_backend} as multimodal attention backend.") + print_info_once(f"Using {qkv_backend} as multimodal attention backend.") self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend]( head_dim=self.head_size, @@ -423,15 +448,16 @@ class VisionAttention(nn.Module): # [s, b, embed_dim] --> [s, b, head * 3 * head_size] qkv, _ = self.qkv_proj(x) - # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size] + # [s, b, head, head_dim_sum] new_x_shape = qkv.size()[:-1] + ( head, - 3 * self.hidden_size_per_attention_head, + self.q_size + 2 * self.kv_size, ) qkv = qkv.view(*new_x_shape) # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size] - q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # [s, b, head, head_size] --> [b, s, head, head_size] q, k, v = [ rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) @@ -468,6 +494,7 @@ class VisionAttention(nn.Module): k=k, v=v, bsz=bsz, + seq_len=s, cu_seqlens=cu_seqlens, attention_mask=attention_mask, ) diff --git a/python/sglang/srt/models/internvl.py b/python/sglang/srt/models/internvl.py index 4b2d4ac50..d466b73b9 100644 --- a/python/sglang/srt/models/internvl.py +++ b/python/sglang/srt/models/internvl.py @@ -11,21 +11,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== - -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py # Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py import torch.nn.functional as F -from einops import rearrange, repeat -from sgl_kernel.flash_attn import flash_attn_varlen_func from torch import nn from transformers import PretrainedConfig, PreTrainedModel from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternTokenPairs, @@ -40,75 +38,12 @@ from sglang.srt.models.qwen2 import Qwen2ForCausalLM from sglang.utils import logger -class FlashAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__( - self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None - ): - super().__init__() - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward( - self, - qkv, - causal=False, - max_s=None, - ): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. 
(B, S, 3, H, D) if key_padding_mask is None - if unpadded: (nnz, 3, h, d) - """ - assert qkv.dtype in [torch.float16, torch.bfloat16] - assert qkv.is_cuda - - batch_size, seqlen, _, nheads, d = qkv.shape - if batch_size == 0 or seqlen == 0: - output_shape = (batch_size, seqlen, nheads, d) - return ( - torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device), - None, - ) - - qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3) - q, k, v = qkv_reshaped.unbind(1) - - max_s = seqlen - cu_seqlens = torch.arange( - 0, - (batch_size + 1) * seqlen, - step=seqlen, - dtype=torch.int32, - device=qkv.device, - ) - output_reshaped = flash_attn_varlen_func( - q, - k, - v, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - softmax_scale=self.softmax_scale, - causal=causal, - ) - output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size) - return output, None - - class InternAttention(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + quant_config: QuantizationConfig = None, + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -116,7 +51,19 @@ class InternAttention(nn.Module): self.head_dim = self.embed_dim // self.num_heads self.scale = self.head_dim**-0.5 - self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias) + + self.attn = VisionAttention( + qkv_backend="fa3", + embed_dim=self.embed_dim, + num_heads=self.num_heads, + projection_size=self.embed_dim, + use_qkv_parallel=True, + quant_config=quant_config, + dropout=getattr(config, "dropout", 0.0), + proj_bias=getattr(config, "qkv_bias", True), + flatten_batch=False, + ) + self.proj_drop = nn.Dropout(config.dropout) self.qk_normalization = config.qk_normalization @@ -125,36 +72,15 @@ class InternAttention(nn.Module): self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) - self.inner_attn = 
FlashAttention(softmax_scale=self.scale) - - self.proj = nn.Linear(self.embed_dim, self.embed_dim) - - def _flash_attn( + def forward( self, - x, - ): - qkv = self.qkv(x) - qkv = rearrange( - qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads - ) - - if self.qk_normalization: - q, k, v = qkv.unbind(2) - q = self.q_norm(q.flatten(-2, -1)).view(q.shape) - k = self.k_norm(k.flatten(-2, -1)).view(k.shape) - qkv = torch.stack([q, k, v], dim=2) - - context, _ = self.inner_attn( - qkv, - ) - outs = self.proj(rearrange(context, "b s h d -> b s (h d)")) - outs = self.proj_drop(outs) + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> torch.Tensor: + out = self.attn(hidden_states, cu_seqlens=cu_seqlens) + outs = self.proj_drop(out) return outs - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - x = self._flash_attn(hidden_states) - return x - class InternVisionEmbeddings(nn.Module): def __init__(self, config: PretrainedConfig): @@ -286,6 +212,7 @@ class InternVisionEncoderLayer(nn.Module): def forward( self, hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, ) -> Tuple[ torch.FloatTensor, Optional[torch.FloatTensor], @@ -295,8 +222,12 @@ class InternVisionEncoderLayer(nn.Module): Args: hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)` """ + hidden_states = hidden_states + self.drop_path1( - self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1 + self.attn( + self.norm1(hidden_states).to(hidden_states.dtype), cu_seqlens=cu_seqlens + ) + * self.ls1 ) hidden_states = hidden_states + self.drop_path2( @@ -363,12 +294,12 @@ class InternVisionEncoder(nn.Module): encoder_states = () if output_hidden_states else None hidden_states = inputs_embeds + cu_seqlens = SingletonCache() + for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - layer_outputs = encoder_layer( 
- hidden_states, - ) + layer_outputs = encoder_layer(hidden_states, cu_seqlens=cu_seqlens) hidden_states = layer_outputs if output_hidden_states: @@ -625,6 +556,7 @@ class InternVLChatModel(nn.Module): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -641,6 +573,11 @@ class InternVLChatModel(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: + if "vision_model" in name: + # adapt to VisionAttention + name = name.replace(r"attn.", r"attn.attn.") + name = name.replace(r"qkv.", r"qkv_proj.") + # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue @@ -665,6 +602,13 @@ class InternVLChatModel(nn.Module): param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + raise RuntimeError( + f"Some weights are not initialized from checkpoints: {unloaded_params}" + ) + return loaded_params EntryClass = InternVLChatModel diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 779843981..fe8dcbf8e 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -17,6 +17,7 @@ import base64 import builtins import ctypes import dataclasses +import functools import importlib import io import ipaddress @@ -1386,6 +1387,11 @@ def print_warning_once(msg: str) -> None: logger.warning(msg, stacklevel=2) +@functools.lru_cache(None) +def print_info_once(msg: str) -> None: + logger.info(msg) + + def get_device_name(device_id: int = 0) -> str: if hasattr(torch, "cuda") and torch.cuda.is_available(): return torch.cuda.get_device_name(device_id)