vlm: adapt internvl to VisionAttention (#6870)

This commit is contained in:
Mick
2025-06-11 16:16:04 +08:00
committed by GitHub
parent 2a5f0100e0
commit 83d87685c5
3 changed files with 105 additions and 128 deletions

View File

@@ -11,21 +11,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==========================582====================================================
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
# Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
import torch.nn.functional as F
from einops import rearrange, repeat
from sgl_kernel.flash_attn import flash_attn_varlen_func
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
@@ -40,75 +38,12 @@ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.utils import logger
class FlashAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(
self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
):
super().__init__()
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(
self,
qkv,
causal=False,
max_s=None,
):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
if unpadded: (nnz, 3, h, d)
"""
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
batch_size, seqlen, _, nheads, d = qkv.shape
if batch_size == 0 or seqlen == 0:
output_shape = (batch_size, seqlen, nheads, d)
return (
torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
None,
)
qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
q, k, v = qkv_reshaped.unbind(1)
max_s = seqlen
cu_seqlens = torch.arange(
0,
(batch_size + 1) * seqlen,
step=seqlen,
dtype=torch.int32,
device=qkv.device,
)
output_reshaped = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
softmax_scale=self.softmax_scale,
causal=causal,
)
output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
return output, None
class InternAttention(nn.Module):
def __init__(self, config):
def __init__(
self,
config,
quant_config: QuantizationConfig = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -116,7 +51,19 @@ class InternAttention(nn.Module):
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
self.attn = VisionAttention(
qkv_backend="fa3",
embed_dim=self.embed_dim,
num_heads=self.num_heads,
projection_size=self.embed_dim,
use_qkv_parallel=True,
quant_config=quant_config,
dropout=getattr(config, "dropout", 0.0),
proj_bias=getattr(config, "qkv_bias", True),
flatten_batch=False,
)
self.proj_drop = nn.Dropout(config.dropout)
self.qk_normalization = config.qk_normalization
@@ -125,36 +72,15 @@ class InternAttention(nn.Module):
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.inner_attn = FlashAttention(softmax_scale=self.scale)
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
def _flash_attn(
def forward(
self,
x,
):
qkv = self.qkv(x)
qkv = rearrange(
qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
)
if self.qk_normalization:
q, k, v = qkv.unbind(2)
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
qkv = torch.stack([q, k, v], dim=2)
context, _ = self.inner_attn(
qkv,
)
outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
outs = self.proj_drop(outs)
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
out = self.attn(hidden_states, cu_seqlens=cu_seqlens)
outs = self.proj_drop(out)
return outs
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = self._flash_attn(hidden_states)
return x
class InternVisionEmbeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
@@ -286,6 +212,7 @@ class InternVisionEncoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> Tuple[
torch.FloatTensor,
Optional[torch.FloatTensor],
@@ -295,8 +222,12 @@ class InternVisionEncoderLayer(nn.Module):
Args:
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
"""
hidden_states = hidden_states + self.drop_path1(
self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
self.attn(
self.norm1(hidden_states).to(hidden_states.dtype), cu_seqlens=cu_seqlens
)
* self.ls1
)
hidden_states = hidden_states + self.drop_path2(
@@ -363,12 +294,12 @@ class InternVisionEncoder(nn.Module):
encoder_states = () if output_hidden_states else None
hidden_states = inputs_embeds
cu_seqlens = SingletonCache()
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
)
layer_outputs = encoder_layer(hidden_states, cu_seqlens=cu_seqlens)
hidden_states = layer_outputs
if output_hidden_states:
@@ -625,6 +556,7 @@ class InternVLChatModel(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
@@ -641,6 +573,11 @@ class InternVLChatModel(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
if "vision_model" in name:
# adapt to VisionAttention
name = name.replace(r"attn.", r"attn.attn.")
name = name.replace(r"qkv.", r"qkv_proj.")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
@@ -665,6 +602,13 @@ class InternVLChatModel(nn.Module):
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
f"Some weights are not initialized from checkpoints: {unloaded_params}"
)
return loaded_params
EntryClass = InternVLChatModel