vlm: adapt internvl to VisionAttention (#6870)

Author: Mick
Date: 2025-06-11 16:16:04 +08:00 (committed by GitHub)
commit 83d87685c5
parent 2a5f0100e0
3 changed files with 105 additions and 128 deletions


@@ -1,15 +1,17 @@
 from __future__ import annotations

+import dataclasses
+import functools
 import math
-from functools import lru_cache, wraps
-from typing import Optional, Tuple
+from functools import lru_cache
+from typing import Any, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, print_info_once

 _is_cuda = is_cuda()
@@ -29,29 +31,42 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import add_prefix, logger
+from sglang.srt.utils import add_prefix

 ROTARY_EMBED_CLASSES = {
     "normal": apply_rotary_pos_emb,
 }


-def execute_once(func):
-    has_run = None
+@dataclasses.dataclass
+class SingletonCache:
+    data: Any = None

-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        nonlocal has_run
-        if not has_run:
-            func(*args, **kwargs)
-            has_run = True
+    def set_data(self, value: Any) -> None:
+        self.data = value

-    return wrapper
+    def get_data(self) -> Optional[Any]:
+        return self.data

+    def empty(self) -> bool:
+        return self.get_data() is None

-@execute_once
-def info_once(message: str):
-    logger.info(message)

+# TODO: requires real seqlens from images
+@functools.lru_cache(maxsize=128)
+def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
+    """
+    Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device.
+    Caches the result based on these parameters.
+    """
+    cu_seqlens = torch.arange(
+        0,
+        (batch_size + 1) * seqlen,
+        step=seqlen,
+        dtype=torch.int32,
+        device=device,
+    )
+    return cu_seqlens


 class VisionSdpaAttention(nn.Module):
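
Note on the new helper: `_get_cu_seqlens_for_shape` assumes every batch element contributes the same number of vision tokens (hence the TODO above about real per-image seqlens), so the boundaries are just an arithmetic progression. A standalone sketch of what it returns, runnable on CPU:

    import torch

    # Two images, four patch tokens each: packed boundaries at 0, 4, 8.
    batch_size, seqlen = 2, 4
    cu_seqlens = torch.arange(
        0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32
    )
    print(cu_seqlens)  # tensor([0, 4, 8], dtype=torch.int32)

Because the helper is wrapped in `functools.lru_cache(maxsize=128)` and keyed on `(batch_size, seqlen, device)`, repeated calls with the same shape return the cached tensor instead of re-allocating it.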
@@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        cu_seqlens: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
-        cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+        elif isinstance(cu_seqlens, SingletonCache):
+            if cu_seqlens.empty():
+                cu_seqlens.set_data(
+                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+                )
+            cu_seqlens = cu_seqlens.get_data()
+
+        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
         output = flash_attn_varlen_func(
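
The rewritten prologue accepts `cu_seqlens` in three forms: a ready tensor, `None` (boundaries derived from `bsz` and `seq_len`), or a `SingletonCache` that is filled on first use. A self-contained CPU sketch of that branch; `resolve_cu_seqlens` is an illustrative name, not part of the diff:

    import dataclasses
    from typing import Any, Optional, Union

    import torch


    @dataclasses.dataclass
    class SingletonCache:  # same shape as the class added above
        data: Any = None

        def set_data(self, value: Any) -> None:
            self.data = value

        def get_data(self) -> Optional[Any]:
            return self.data

        def empty(self) -> bool:
            return self.get_data() is None


    def resolve_cu_seqlens(cu_seqlens, bsz: int, seq_len: int) -> torch.Tensor:
        def uniform() -> torch.Tensor:
            return torch.arange(
                0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32
            )

        if cu_seqlens is None:
            cu_seqlens = uniform()
        elif isinstance(cu_seqlens, SingletonCache):
            if cu_seqlens.empty():
                cu_seqlens.set_data(uniform())
            cu_seqlens = cu_seqlens.get_data()
        return cu_seqlens.to(dtype=torch.int32)


    cache = SingletonCache()
    first = resolve_cu_seqlens(cache, bsz=2, seq_len=3)
    second = resolve_cu_seqlens(cache, bsz=2, seq_len=3)
    assert first is second  # later layers reuse the tensor computed by the first
    print(first)  # tensor([0, 3, 6], dtype=torch.int32)

The point of the mutable cache object is that a model can hand one `SingletonCache` to all of its attention layers: the first layer pays for the computation, every subsequent layer gets a cache hit.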
@@ -346,11 +371,11 @@ class VisionAttention(nn.Module):
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
-            info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
+            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]
-            info_once(f"Using {qkv_backend} as multimodal attention backend.")
+            print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
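
`info_once` and its `execute_once` decorator are deleted in favor of the shared `print_info_once` from `sglang.srt.utils`. One behavioral nuance: the removed decorator fired exactly once per process regardless of arguments, whereas a message-keyed helper (a plausible sketch of `print_info_once` below; its real implementation lives in utils and is an assumption here) logs once per distinct message:

    import logging
    from functools import lru_cache

    logger = logging.getLogger(__name__)


    @lru_cache(maxsize=None)
    def log_info_once(message: str) -> None:
        # Keyed on the message text: each distinct string is logged once;
        # repeats hit the lru_cache and are swallowed.
        logger.info(message)


    log_info_once("Using sdpa as multimodal attention backend.")
    log_info_once("Using sdpa as multimodal attention backend.")  # not logged again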
@@ -423,15 +448,16 @@ class VisionAttention(nn.Module):
             # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
             qkv, _ = self.qkv_proj(x)
-            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
+            # [s, b, head, head_dim_sum]
             new_x_shape = qkv.size()[:-1] + (
                 head,
-                3 * self.hidden_size_per_attention_head,
+                self.q_size + 2 * self.kv_size,
             )
             qkv = qkv.view(*new_x_shape)

-            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
-            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

             # [s, b, head, head_size] --> [b, s, head, head_size]
             q, k, v = [
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]
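
The old `dist_utils.split_tensor_along_last_dim(qkv, 3)` can only cut the fused projection into three equal parts; the explicit `qkv.split([...])` also handles layouts where the query width differs from the key/value width. A toy shape check (sizes are illustrative, not InternVL's real dimensions):

    import torch

    s, b, head = 5, 2, 4
    q_size, kv_size = 16, 8  # deliberately unequal widths

    qkv = torch.randn(s, b, head, q_size + 2 * kv_size)
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    print(q.shape, k.shape, v.shape)
    # torch.Size([5, 2, 4, 16]) torch.Size([5, 2, 4, 8]) torch.Size([5, 2, 4, 8])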
@@ -468,6 +494,7 @@ class VisionAttention(nn.Module):
k=k,
v=v,
bsz=bsz,
seq_len=s,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
)