vlm: adapt internvl to VisionAttention (#6870)
@@ -11,21 +11,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union

import torch

# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
# Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
import torch.nn.functional as F
from einops import rearrange, repeat
from sgl_kernel.flash_attn import flash_attn_varlen_func
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling

from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternTokenPairs,
@@ -40,75 +38,12 @@ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.utils import logger


-class FlashAttention(nn.Module):
-    """Implement the scaled dot product attention with softmax.
-    Arguments
-    ---------
-        softmax_scale: The temperature to use for the softmax attention.
-            (default: 1/sqrt(d_keys) where d_keys is computed at
-            runtime)
-        attention_dropout: The dropout rate to apply to the attention
-            (default: 0.0)
-    """
-
-    def __init__(
-        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
-    ):
-        super().__init__()
-        self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
-
-    def forward(
-        self,
-        qkv,
-        causal=False,
-        max_s=None,
-    ):
-        """Implements the multihead softmax attention.
-        Arguments
-        ---------
-            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
-                if unpadded: (nnz, 3, h, d)
-        """
-        assert qkv.dtype in [torch.float16, torch.bfloat16]
-        assert qkv.is_cuda
-
-        batch_size, seqlen, _, nheads, d = qkv.shape
-        if batch_size == 0 or seqlen == 0:
-            output_shape = (batch_size, seqlen, nheads, d)
-            return (
-                torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
-                None,
-            )
-
-        qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
-        q, k, v = qkv_reshaped.unbind(1)
-
-        max_s = seqlen
-        cu_seqlens = torch.arange(
-            0,
-            (batch_size + 1) * seqlen,
-            step=seqlen,
-            dtype=torch.int32,
-            device=qkv.device,
-        )
-        output_reshaped = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            softmax_scale=self.softmax_scale,
-            causal=causal,
-        )
-        output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
-        return output, None


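The wrapper deleted above drove flash_attn_varlen_func with a packed QKV tensor of shape (B, S, 3, H, D) and a synthetic cu_seqlens, since every image sequence in the batch has the same length. A minimal sketch of that bookkeeping, using only plain PyTorch and toy sizes (no kernel call; names and shapes follow the deleted code):

import torch
from einops import rearrange

# Toy sizes, chosen only for illustration.
B, S, H, D = 2, 5, 4, 8
qkv = torch.randn(B, S, 3, H, D)

# Flatten the batch the way the deleted FlashAttention.forward did:
# with equal-length sequences the cumulative lengths are 0, S, 2S, ..., B*S.
qkv_flat = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
q, k, v = qkv_flat.unbind(1)  # each (B*S, H, D)
cu_seqlens = torch.arange(0, (B + 1) * S, step=S, dtype=torch.int32)

assert cu_seqlens.tolist() == [0, 5, 10]
assert q.shape == (B * S, H, D)
# flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, S, S, ...) would then
# treat rows [cu_seqlens[i], cu_seqlens[i+1]) as one sequence.
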
class InternAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(
+        self,
+        config,
+        quant_config: QuantizationConfig = None,
+    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
@@ -116,7 +51,19 @@ class InternAttention(nn.Module):
        self.head_dim = self.embed_dim // self.num_heads

        self.scale = self.head_dim**-0.5
-        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+        self.attn = VisionAttention(
+            qkv_backend="fa3",
+            embed_dim=self.embed_dim,
+            num_heads=self.num_heads,
+            projection_size=self.embed_dim,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=getattr(config, "dropout", 0.0),
+            proj_bias=getattr(config, "qkv_bias", True),
+            flatten_batch=False,
+        )

        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization
@@ -125,36 +72,15 @@ class InternAttention(nn.Module):
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

-        self.inner_attn = FlashAttention(softmax_scale=self.scale)
-
-        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
-
-    def _flash_attn(
+    def forward(
        self,
-        x,
-    ):
-        qkv = self.qkv(x)
-        qkv = rearrange(
-            qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
-        )
-
-        if self.qk_normalization:
-            q, k, v = qkv.unbind(2)
-            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
-            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
-            qkv = torch.stack([q, k, v], dim=2)
-
-        context, _ = self.inner_attn(
-            qkv,
-        )
-        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
-        outs = self.proj_drop(outs)
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+    ) -> torch.Tensor:
+        out = self.attn(hidden_states, cu_seqlens=cu_seqlens)
+        outs = self.proj_drop(out)
        return outs

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        x = self._flash_attn(hidden_states)
-        return x


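The deleted _flash_attn path above applied QK normalization by flattening the head dimensions, running RMSNorm over (h * d), and then restoring the original shape. A small self-contained sketch of that flatten/view pattern; the ToyRMSNorm below is an assumed stand-in, not InternRMSNorm itself:

import torch
from torch import nn


class ToyRMSNorm(nn.Module):
    # Stand-in for InternRMSNorm; only the call shape matters for this sketch.
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        var = x.float().pow(2).mean(-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype) * self.weight


b, s, h, d = 2, 5, 4, 8
q = torch.randn(b, s, h, d)
q_norm = ToyRMSNorm(h * d)

# Same trick as the removed code: normalize across all heads at once,
# then view back to (b, s, h, d).
q_out = q_norm(q.flatten(-2, -1)).view(q.shape)
assert q_out.shape == q.shape
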
class InternVisionEmbeddings(nn.Module):
    def __init__(self, config: PretrainedConfig):
@@ -286,6 +212,7 @@ class InternVisionEncoderLayer(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
    ) -> Tuple[
        torch.FloatTensor,
        Optional[torch.FloatTensor],
@@ -295,8 +222,12 @@ class InternVisionEncoderLayer(nn.Module):
        Args:
            hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
        """

        hidden_states = hidden_states + self.drop_path1(
-            self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+            self.attn(
+                self.norm1(hidden_states).to(hidden_states.dtype), cu_seqlens=cu_seqlens
+            )
+            * self.ls1
        )

        hidden_states = hidden_states + self.drop_path2(
@@ -363,12 +294,12 @@ class InternVisionEncoder(nn.Module):
        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

+        cu_seqlens = SingletonCache()

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-            )
+            layer_outputs = encoder_layer(hidden_states, cu_seqlens=cu_seqlens)
            hidden_states = layer_outputs

            if output_hidden_states:
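The encoder now creates one cu_seqlens holder (SingletonCache) outside the layer loop and hands the same object to every layer, so the sequence-length bookkeeping can be derived once and reused rather than recomputed per layer. A rough sketch of that compute-once pattern under that assumption; the helper below is illustrative only and is not SingletonCache's actual API:

import torch


class OnceCache:
    # Illustrative compute-once holder, assumed for this sketch only.
    def __init__(self):
        self._value = None

    def get(self, compute):
        if self._value is None:
            self._value = compute()
        return self._value


def fake_layer(x, cache):
    # Every layer asks the shared cache; only the first call computes.
    cu_seqlens = cache.get(
        lambda: torch.arange(
            0, (x.shape[0] + 1) * x.shape[1], step=x.shape[1], dtype=torch.int32
        )
    )
    return x, cu_seqlens


x = torch.randn(2, 5, 16)
cache = OnceCache()
for _ in range(3):
    x, cu = fake_layer(x, cache)
assert cu.tolist() == [0, 5, 10]
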
@@ -625,6 +556,7 @@ class InternVLChatModel(nn.Module):
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
@@ -641,6 +573,11 @@ class InternVLChatModel(nn.Module):
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
+                if "vision_model" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.", r"attn.attn.")
+                    name = name.replace(r"qkv.", r"qkv_proj.")
+
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
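The two replace calls added above redirect vision-tower checkpoint keys into the wrapped VisionAttention module, whose qkv projection is named qkv_proj and which sits one level deeper under attn. A quick illustration of the remapping, using a hypothetical checkpoint key:

# Hypothetical InternVL checkpoint key, used only to illustrate the remapping.
name = "vision_model.encoder.layers.0.attn.qkv.weight"
if "vision_model" in name:
    # Same substitutions as in load_weights above.
    name = name.replace(r"attn.", r"attn.attn.")
    name = name.replace(r"qkv.", r"qkv_proj.")
assert name == "vision_model.encoder.layers.0.attn.attn.qkv_proj.weight"
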
@@ -665,6 +602,13 @@ class InternVLChatModel(nn.Module):
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                f"Some weights are not initialized from checkpoints: {unloaded_params}"
+            )
+        return loaded_params


EntryClass = InternVLChatModel