From 9a0d0b754deb3d222eaea9fbd2bf57a6b222ab1d Mon Sep 17 00:00:00 2001 From: Vincent Zhong <207368749+vincentzed@users.noreply.github.com> Date: Sun, 31 Aug 2025 05:20:50 -0400 Subject: [PATCH] [Performance] Improve Qwen RMSNorm by replacing with native RMSNorm op (#9709) --- python/sglang/srt/models/qwen2_5_vl.py | 62 ++++++++++++++++++-------- 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 59f3e6370..7ffb2e89b 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -31,7 +31,6 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange from transformers.activations import ACT2FN -from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig, @@ -43,6 +42,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType @@ -122,8 +122,8 @@ class Qwen2_5_VisionBlock(nn.Module): super().__init__() if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) - self.norm1 = Qwen2RMSNorm(dim, eps=1e-6) - self.norm2 = Qwen2RMSNorm(dim, eps=1e-6) + self.norm1 = RMSNorm(dim, eps=1e-6) + self.norm2 = RMSNorm(dim, eps=1e-6) if attn_implementation is None: softmax_in_single_precision = False @@ -174,18 +174,29 @@ class Qwen2_5_VisionBlock(nn.Module): cu_seqlens: torch.Tensor, position_embeddings: torch.Tensor, ) -> torch.Tensor: - hidden_states = self.norm1(x) - hidden_states = rearrange(hidden_states, "s b ... -> b s ...") + S, B, H = x.shape + # norm1: flatten to 2D -> [S*B, H], then reshape back + x2d = x.reshape(-1, H) + hidden_states = self.norm1(x2d).reshape(S, B, H) + + # Attention expects [B, S, H] + hidden_states = rearrange(hidden_states, "s b h -> b s h") attn = self.attn( hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, ) - attn = rearrange(attn, "b s ... -> s b ...") - x = x + attn - norm2 = self.norm2(x) - mlp = self.mlp(norm2) - x = x + mlp + attn = rearrange(attn, "b s h -> s b h") + + # norm2 with fused residual-add: also 2D + attn2d = attn.reshape(-1, H) + x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d) + x_norm = x_norm_2d.reshape(S, B, H) + x_after_add = x_after_add_2d.reshape(S, B, H) + + # MLP and final residual + mlp_out = self.mlp(x_norm) + x = x_after_add + mlp_out return x @@ -201,7 +212,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module): ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) - self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.ln_q = RMSNorm(context_dim, eps=1e-6) self.mlp = nn.ModuleList( [ ColumnParallelLinear( @@ -223,11 +234,13 @@ class Qwen2_5_VisionPatchMerger(nn.Module): ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.ln_q(x) - x = x.view(-1, self.hidden_size) - + # x expected shape: [S, B, context_dim] + S, B, D = x.shape + x2d = x.reshape(-1, D) + x2d = self.ln_q(x2d) # RMSNorm expects 2D + x2d = x2d.view(-1, self.hidden_size) # group into spatial_merge_unit mlp_fc1, mlp_act, mlp_fc2 = self.mlp - x_parallel, _ = mlp_fc1(x) + x_parallel, _ = mlp_fc1(x2d) x_parallel = mlp_act(x_parallel) out, _ = mlp_fc2(x_parallel) return out @@ -394,6 +407,12 @@ class Qwen2_5_VisionTransformer(nn.Module): ) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + # Move window_index to the same device as x before using it to index x + window_index = window_index.to(device=x.device) + + # Ensure rotary_pos_emb is on the same device/dtype as x + rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype) + seq_len, _ = x.size() x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) @@ -406,12 +425,19 @@ class Qwen2_5_VisionTransformer(nn.Module): rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) + # After building position_embeddings, make sure both cos and sin are on the same device/dtype as the attention input + position_embeddings = ( + position_embeddings[0].to(x.device, x.dtype), + position_embeddings[1].to(x.device, x.dtype), + ) - # compute cu_seqlens + # compute cu_seqlens - move cu_seqlens to GPU and make it int32 cu_seqlens = torch.cat( [ - torch.tensor([0], device=grid_thw.device), - (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0), + torch.tensor([0], device=x.device, dtype=torch.int32), + (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]) + .cumsum(dim=0) + .to(device=x.device, dtype=torch.int32), ] ) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)