vlm: adapt internvl to VisionAttention (#6870)
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import math
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Optional, Tuple
|
||||
from functools import lru_cache
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from sglang.srt.utils import is_cuda
|
||||
from sglang.srt.utils import is_cuda, print_info_once
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
@@ -29,29 +31,42 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.quantization import QuantizationConfig
|
||||
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import add_prefix, logger
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
ROTARY_EMBED_CLASSES = {
|
||||
"normal": apply_rotary_pos_emb,
|
||||
}
|
||||
|
||||
|
||||
def execute_once(func):
|
||||
has_run = None
|
||||
@dataclasses.dataclass
|
||||
class SingletonCache:
|
||||
data: Any = None
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal has_run
|
||||
if not has_run:
|
||||
func(*args, **kwargs)
|
||||
has_run = True
|
||||
def set_data(self, value: Any) -> None:
|
||||
self.data = value
|
||||
|
||||
return wrapper
|
||||
def get_data(self) -> Optional[Any]:
|
||||
return self.data
|
||||
|
||||
def empty(self) -> bool:
|
||||
return self.get_data() is None
|
||||
|
||||
|
||||
@execute_once
|
||||
def info_once(message: str):
|
||||
logger.info(message)
|
||||
# TODO: requires real seqlens from images
|
||||
@functools.lru_cache(maxsize=128)
def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
    """Return cumulative sequence lengths for equally sized sequences.

    Builds the int32 tensor ``[0, seqlen, 2 * seqlen, ..., batch_size * seqlen]``
    (``batch_size + 1`` entries) on ``device``.

    The result is memoized by ``functools.lru_cache`` on
    ``(batch_size, seqlen, device)``, so repeated calls with the same arguments
    return the *same* tensor object. Callers must treat it as read-only; an
    in-place modification would be visible to every later caller.

    Args:
        batch_size: Number of sequences in the batch.
        seqlen: Length shared by every sequence in the batch.
        device: Device on which the tensor is created. Must be hashable,
            because it is part of the cache key.

    Returns:
        A 1-D ``torch.int32`` tensor of cumulative offsets.
    """
    # Scale unit offsets rather than calling arange with ``step=seqlen``:
    # torch.arange rejects a zero step, so ``seqlen == 0`` would raise instead
    # of yielding the all-zero offsets. For non-zero seqlen the values and
    # dtype are identical to the stepped form.
    unit_offsets = torch.arange(batch_size + 1, dtype=torch.int32, device=device)
    return unit_offsets * seqlen
|
||||
|
||||
|
||||
class VisionSdpaAttention(nn.Module):
|
||||
@@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module):
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
|
||||
bsz: int,
|
||||
seq_len: int,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
@@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module):
|
||||
Returns:
|
||||
[b * s, h, head_size]
|
||||
"""
|
||||
cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
|
||||
if cu_seqlens is None:
|
||||
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
|
||||
elif isinstance(cu_seqlens, SingletonCache):
|
||||
if cu_seqlens.empty():
|
||||
cu_seqlens.set_data(
|
||||
_get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
|
||||
)
|
||||
cu_seqlens = cu_seqlens.get_data()
|
||||
|
||||
cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
|
||||
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
max_seqlen = seq_lens.max().item()
|
||||
output = flash_attn_varlen_func(
|
||||
@@ -346,11 +371,11 @@ class VisionAttention(nn.Module):
|
||||
if global_server_args_dict["mm_attention_backend"] is None:
|
||||
if qkv_backend is None:
|
||||
qkv_backend = "sdpa"
|
||||
info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
|
||||
print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
|
||||
else:
|
||||
qkv_backend = global_server_args_dict["mm_attention_backend"]
|
||||
|
||||
info_once(f"Using {qkv_backend} as multimodal attention backend.")
|
||||
print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
|
||||
|
||||
self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
|
||||
head_dim=self.head_size,
|
||||
@@ -423,15 +448,16 @@ class VisionAttention(nn.Module):
|
||||
# [s, b, embed_dim] --> [s, b, head * 3 * head_size]
|
||||
qkv, _ = self.qkv_proj(x)
|
||||
|
||||
# [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
|
||||
# [s, b, head, head_dim_sum]
|
||||
new_x_shape = qkv.size()[:-1] + (
|
||||
head,
|
||||
3 * self.hidden_size_per_attention_head,
|
||||
self.q_size + 2 * self.kv_size,
|
||||
)
|
||||
qkv = qkv.view(*new_x_shape)
|
||||
|
||||
# [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
|
||||
q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
# [s, b, head, head_size] --> [b, s, head, head_size]
|
||||
q, k, v = [
|
||||
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
|
||||
@@ -468,6 +494,7 @@ class VisionAttention(nn.Module):
|
||||
k=k,
|
||||
v=v,
|
||||
bsz=bsz,
|
||||
seq_len=s,
|
||||
cu_seqlens=cu_seqlens,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user