vlm: adapt internvl to VisionAttention (#6870)
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import math
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Optional, Tuple
|
||||
from functools import lru_cache
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from sglang.srt.utils import is_cuda
|
||||
from sglang.srt.utils import is_cuda, print_info_once
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
@@ -29,29 +31,42 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.quantization import QuantizationConfig
|
||||
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import add_prefix, logger
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
ROTARY_EMBED_CLASSES = {
|
||||
"normal": apply_rotary_pos_emb,
|
||||
}
|
||||
|
||||
|
||||
def execute_once(func):
|
||||
has_run = None
|
||||
@dataclasses.dataclass
|
||||
class SingletonCache:
|
||||
data: Any = None
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal has_run
|
||||
if not has_run:
|
||||
func(*args, **kwargs)
|
||||
has_run = True
|
||||
def set_data(self, value: Any) -> None:
|
||||
self.data = value
|
||||
|
||||
return wrapper
|
||||
def get_data(self) -> Optional[Any]:
|
||||
return self.data
|
||||
|
||||
def empty(self) -> bool:
|
||||
return self.get_data() is None
|
||||
|
||||
|
||||
@execute_once
|
||||
def info_once(message: str):
|
||||
logger.info(message)
|
||||
# TODO: requires real seqlens from images
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
|
||||
"""
|
||||
Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device.
|
||||
Caches the result based on these parameters.
|
||||
"""
|
||||
cu_seqlens = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen,
|
||||
step=seqlen,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
return cu_seqlens
|
||||
|
||||
|
||||
class VisionSdpaAttention(nn.Module):
|
||||
@@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module):
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
|
||||
bsz: int,
|
||||
seq_len: int,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
@@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module):
|
||||
Returns:
|
||||
[b * s, h, head_size]
|
||||
"""
|
||||
cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
|
||||
if cu_seqlens is None:
|
||||
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
|
||||
elif isinstance(cu_seqlens, SingletonCache):
|
||||
if cu_seqlens.empty():
|
||||
cu_seqlens.set_data(
|
||||
_get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
|
||||
)
|
||||
cu_seqlens = cu_seqlens.get_data()
|
||||
|
||||
cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
|
||||
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
max_seqlen = seq_lens.max().item()
|
||||
output = flash_attn_varlen_func(
|
||||
@@ -346,11 +371,11 @@ class VisionAttention(nn.Module):
|
||||
if global_server_args_dict["mm_attention_backend"] is None:
|
||||
if qkv_backend is None:
|
||||
qkv_backend = "sdpa"
|
||||
info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
|
||||
print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
|
||||
else:
|
||||
qkv_backend = global_server_args_dict["mm_attention_backend"]
|
||||
|
||||
info_once(f"Using {qkv_backend} as multimodal attention backend.")
|
||||
print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
|
||||
|
||||
self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
|
||||
head_dim=self.head_size,
|
||||
@@ -423,15 +448,16 @@ class VisionAttention(nn.Module):
|
||||
# [s, b, embed_dim] --> [s, b, head * 3 * head_size]
|
||||
qkv, _ = self.qkv_proj(x)
|
||||
|
||||
# [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
|
||||
# [s, b, head, head_dim_sum]
|
||||
new_x_shape = qkv.size()[:-1] + (
|
||||
head,
|
||||
3 * self.hidden_size_per_attention_head,
|
||||
self.q_size + 2 * self.kv_size,
|
||||
)
|
||||
qkv = qkv.view(*new_x_shape)
|
||||
|
||||
# [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
|
||||
q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
# [s, b, head, head_size] --> [b, s, head, head_size]
|
||||
q, k, v = [
|
||||
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
|
||||
@@ -468,6 +494,7 @@ class VisionAttention(nn.Module):
|
||||
k=k,
|
||||
v=v,
|
||||
bsz=bsz,
|
||||
seq_len=s,
|
||||
cu_seqlens=cu_seqlens,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
@@ -11,21 +11,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==========================582====================================================
|
||||
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
|
||||
# Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
|
||||
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
@@ -40,75 +38,12 @@ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
from sglang.utils import logger
|
||||
|
||||
|
||||
class FlashAttention(nn.Module):
|
||||
"""Implement the scaled dot product attention with softmax.
|
||||
Arguments
|
||||
---------
|
||||
softmax_scale: The temperature to use for the softmax attention.
|
||||
(default: 1/sqrt(d_keys) where d_keys is computed at
|
||||
runtime)
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
self.softmax_scale = softmax_scale
|
||||
self.dropout_p = attention_dropout
|
||||
|
||||
def forward(
|
||||
self,
|
||||
qkv,
|
||||
causal=False,
|
||||
max_s=None,
|
||||
):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
|
||||
if unpadded: (nnz, 3, h, d)
|
||||
"""
|
||||
assert qkv.dtype in [torch.float16, torch.bfloat16]
|
||||
assert qkv.is_cuda
|
||||
|
||||
batch_size, seqlen, _, nheads, d = qkv.shape
|
||||
if batch_size == 0 or seqlen == 0:
|
||||
output_shape = (batch_size, seqlen, nheads, d)
|
||||
return (
|
||||
torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
|
||||
None,
|
||||
)
|
||||
|
||||
qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
|
||||
q, k, v = qkv_reshaped.unbind(1)
|
||||
|
||||
max_s = seqlen
|
||||
cu_seqlens = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen,
|
||||
step=seqlen,
|
||||
dtype=torch.int32,
|
||||
device=qkv.device,
|
||||
)
|
||||
output_reshaped = flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
|
||||
return output, None
|
||||
|
||||
|
||||
class InternAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: QuantizationConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@@ -116,7 +51,19 @@ class InternAttention(nn.Module):
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
|
||||
|
||||
self.attn = VisionAttention(
|
||||
qkv_backend="fa3",
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
projection_size=self.embed_dim,
|
||||
use_qkv_parallel=True,
|
||||
quant_config=quant_config,
|
||||
dropout=getattr(config, "dropout", 0.0),
|
||||
proj_bias=getattr(config, "qkv_bias", True),
|
||||
flatten_batch=False,
|
||||
)
|
||||
|
||||
self.proj_drop = nn.Dropout(config.dropout)
|
||||
|
||||
self.qk_normalization = config.qk_normalization
|
||||
@@ -125,36 +72,15 @@ class InternAttention(nn.Module):
|
||||
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.inner_attn = FlashAttention(softmax_scale=self.scale)
|
||||
|
||||
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _flash_attn(
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
qkv = self.qkv(x)
|
||||
qkv = rearrange(
|
||||
qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
|
||||
)
|
||||
|
||||
if self.qk_normalization:
|
||||
q, k, v = qkv.unbind(2)
|
||||
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
|
||||
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
|
||||
context, _ = self.inner_attn(
|
||||
qkv,
|
||||
)
|
||||
outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
|
||||
outs = self.proj_drop(outs)
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
out = self.attn(hidden_states, cu_seqlens=cu_seqlens)
|
||||
outs = self.proj_drop(out)
|
||||
return outs
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
x = self._flash_attn(hidden_states)
|
||||
return x
|
||||
|
||||
|
||||
class InternVisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
@@ -286,6 +212,7 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor,
|
||||
Optional[torch.FloatTensor],
|
||||
@@ -295,8 +222,12 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
Args:
|
||||
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
"""
|
||||
|
||||
hidden_states = hidden_states + self.drop_path1(
|
||||
self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
|
||||
self.attn(
|
||||
self.norm1(hidden_states).to(hidden_states.dtype), cu_seqlens=cu_seqlens
|
||||
)
|
||||
* self.ls1
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + self.drop_path2(
|
||||
@@ -363,12 +294,12 @@ class InternVisionEncoder(nn.Module):
|
||||
encoder_states = () if output_hidden_states else None
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
cu_seqlens = SingletonCache()
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
)
|
||||
layer_outputs = encoder_layer(hidden_states, cu_seqlens=cu_seqlens)
|
||||
hidden_states = layer_outputs
|
||||
|
||||
if output_hidden_states:
|
||||
@@ -625,6 +556,7 @@ class InternVLChatModel(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
@@ -641,6 +573,11 @@ class InternVLChatModel(nn.Module):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if "vision_model" in name:
|
||||
# adapt to VisionAttention
|
||||
name = name.replace(r"attn.", r"attn.attn.")
|
||||
name = name.replace(r"qkv.", r"qkv_proj.")
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
@@ -665,6 +602,13 @@ class InternVLChatModel(nn.Module):
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
unloaded_params = params_dict.keys() - loaded_params
|
||||
if unloaded_params:
|
||||
raise RuntimeError(
|
||||
f"Some weights are not initialized from checkpoints: {unloaded_params}"
|
||||
)
|
||||
return loaded_params
|
||||
|
||||
|
||||
EntryClass = InternVLChatModel
|
||||
|
||||
@@ -17,6 +17,7 @@ import base64
|
||||
import builtins
|
||||
import ctypes
|
||||
import dataclasses
|
||||
import functools
|
||||
import importlib
|
||||
import io
|
||||
import ipaddress
|
||||
@@ -1386,6 +1387,11 @@ def print_warning_once(msg: str) -> None:
|
||||
logger.warning(msg, stacklevel=2)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def print_info_once(msg: str) -> None:
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
def get_device_name(device_id: int = 0) -> str:
|
||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||
return torch.cuda.get_device_name(device_id)
|
||||
|
||||
Reference in New Issue
Block a user