[Performance] Improve Qwen RMSNorm by replacing with native RMSNorm op (#9709)
This commit is contained in:
@@ -31,7 +31,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig,
|
||||
Qwen2_5_VLVisionConfig,
|
||||
@@ -43,6 +42,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
@@ -122,8 +122,8 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||
self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
|
||||
self.norm1 = RMSNorm(dim, eps=1e-6)
|
||||
self.norm2 = RMSNorm(dim, eps=1e-6)
|
||||
|
||||
if attn_implementation is None:
|
||||
softmax_in_single_precision = False
|
||||
@@ -174,18 +174,29 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
position_embeddings: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.norm1(x)
|
||||
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
|
||||
S, B, H = x.shape
|
||||
# norm1: flatten to 2D -> [S*B, H], then reshape back
|
||||
x2d = x.reshape(-1, H)
|
||||
hidden_states = self.norm1(x2d).reshape(S, B, H)
|
||||
|
||||
# Attention expects [B, S, H]
|
||||
hidden_states = rearrange(hidden_states, "s b h -> b s h")
|
||||
attn = self.attn(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
attn = rearrange(attn, "b s ... -> s b ...")
|
||||
x = x + attn
|
||||
norm2 = self.norm2(x)
|
||||
mlp = self.mlp(norm2)
|
||||
x = x + mlp
|
||||
attn = rearrange(attn, "b s h -> s b h")
|
||||
|
||||
# norm2 with fused residual-add: also 2D
|
||||
attn2d = attn.reshape(-1, H)
|
||||
x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
|
||||
x_norm = x_norm_2d.reshape(S, B, H)
|
||||
x_after_add = x_after_add_2d.reshape(S, B, H)
|
||||
|
||||
# MLP and final residual
|
||||
mlp_out = self.mlp(x_norm)
|
||||
x = x_after_add + mlp_out
|
||||
return x
|
||||
|
||||
|
||||
@@ -201,7 +212,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||
self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
|
||||
self.ln_q = RMSNorm(context_dim, eps=1e-6)
|
||||
self.mlp = nn.ModuleList(
|
||||
[
|
||||
ColumnParallelLinear(
|
||||
@@ -223,11 +234,13 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.ln_q(x)
|
||||
x = x.view(-1, self.hidden_size)
|
||||
|
||||
# x expected shape: [S, B, context_dim]
|
||||
S, B, D = x.shape
|
||||
x2d = x.reshape(-1, D)
|
||||
x2d = self.ln_q(x2d) # RMSNorm expects 2D
|
||||
x2d = x2d.view(-1, self.hidden_size) # group into spatial_merge_unit
|
||||
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
|
||||
x_parallel, _ = mlp_fc1(x)
|
||||
x_parallel, _ = mlp_fc1(x2d)
|
||||
x_parallel = mlp_act(x_parallel)
|
||||
out, _ = mlp_fc2(x_parallel)
|
||||
return out
|
||||
@@ -394,6 +407,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
)
|
||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||
|
||||
# Move window_index to the same device as x before using it to index x
|
||||
window_index = window_index.to(device=x.device)
|
||||
|
||||
# Ensure rotary_pos_emb is on the same device/dtype as x
|
||||
rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
seq_len, _ = x.size()
|
||||
|
||||
x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||
@@ -406,12 +425,19 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
position_embeddings = (emb.cos(), emb.sin())
|
||||
# After building position_embeddings, make sure both cos and sin are on the same device/dtype as the attention input
|
||||
position_embeddings = (
|
||||
position_embeddings[0].to(x.device, x.dtype),
|
||||
position_embeddings[1].to(x.device, x.dtype),
|
||||
)
|
||||
|
||||
# compute cu_seqlens
|
||||
# compute cu_seqlens - move cu_seqlens to GPU and make it int32
|
||||
cu_seqlens = torch.cat(
|
||||
[
|
||||
torch.tensor([0], device=grid_thw.device),
|
||||
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
|
||||
torch.tensor([0], device=x.device, dtype=torch.int32),
|
||||
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
|
||||
.cumsum(dim=0)
|
||||
.to(device=x.device, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||
|
||||
Reference in New Issue
Block a user