[Performance] Improve Qwen RMSNorm by replacing with native RMSNorm op (#9709)
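Note on the fused call used in Qwen2_5_VisionBlock.forward below: sglang's RMSNorm, when called with a residual, is expected to add the residual first and return both the normalized result and the post-add activations (this is what lets `self.norm2(x2d, residual=attn2d)` in the diff return two tensors). A minimal pure-PyTorch sketch of those semantics; the helper name is illustrative only and the fused kernel's exact dtype handling may differ:

import torch

def fused_add_rmsnorm_reference(x2d: torch.Tensor, residual2d: torch.Tensor,
                                weight: torch.Tensor, eps: float = 1e-6):
    # Sketch of the expected semantics of RMSNorm(x, residual=...):
    # add the residual, normalize the sum over the hidden dimension,
    # and return (normalized, residual_after_add).
    added = x2d + residual2d
    variance = added.pow(2).mean(dim=-1, keepdim=True)
    normalized = added * torch.rsqrt(variance + eps) * weight
    return normalized, added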
@@ -31,7 +31,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from transformers.activations import ACT2FN
-from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
@@ -43,6 +42,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
@@ -122,8 +122,8 @@ class Qwen2_5_VisionBlock(nn.Module):
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
-        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=1e-6)
+        self.norm2 = RMSNorm(dim, eps=1e-6)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -174,18 +174,29 @@ class Qwen2_5_VisionBlock(nn.Module):
         cu_seqlens: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
-        hidden_states = self.norm1(x)
-        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+        S, B, H = x.shape
+        # norm1: flatten to 2D -> [S*B, H], then reshape back
+        x2d = x.reshape(-1, H)
+        hidden_states = self.norm1(x2d).reshape(S, B, H)
+
+        # Attention expects [B, S, H]
+        hidden_states = rearrange(hidden_states, "s b h -> b s h")
         attn = self.attn(
             hidden_states,
             cu_seqlens=cu_seqlens,
             position_embeddings=position_embeddings,
         )
-        attn = rearrange(attn, "b s ... -> s b ...")
-        x = x + attn
-        norm2 = self.norm2(x)
-        mlp = self.mlp(norm2)
-        x = x + mlp
+        attn = rearrange(attn, "b s h -> s b h")
+
+        # norm2 with fused residual-add: also 2D
+        attn2d = attn.reshape(-1, H)
+        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+        x_norm = x_norm_2d.reshape(S, B, H)
+        x_after_add = x_after_add_2d.reshape(S, B, H)
+
+        # MLP and final residual
+        mlp_out = self.mlp(x_norm)
+        x = x_after_add + mlp_out
         return x
 
 
@@ -201,7 +212,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
-        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+        self.ln_q = RMSNorm(context_dim, eps=1e-6)
         self.mlp = nn.ModuleList(
             [
                 ColumnParallelLinear(
@@ -223,11 +234,13 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x = self.ln_q(x)
|
# x expected shape: [S, B, context_dim]
|
||||||
x = x.view(-1, self.hidden_size)
|
S, B, D = x.shape
|
||||||
|
x2d = x.reshape(-1, D)
|
||||||
|
x2d = self.ln_q(x2d) # RMSNorm expects 2D
|
||||||
|
x2d = x2d.view(-1, self.hidden_size) # group into spatial_merge_unit
|
||||||
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
|
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
|
||||||
x_parallel, _ = mlp_fc1(x)
|
x_parallel, _ = mlp_fc1(x2d)
|
||||||
x_parallel = mlp_act(x_parallel)
|
x_parallel = mlp_act(x_parallel)
|
||||||
out, _ = mlp_fc2(x_parallel)
|
out, _ = mlp_fc2(x_parallel)
|
||||||
return out
|
return out
|
||||||
@@ -394,6 +407,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
 
+        # Move window_index to the same device as x before using it to index x
+        window_index = window_index.to(device=x.device)
+
+        # Ensure rotary_pos_emb is on the same device/dtype as x
+        rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
+
         seq_len, _ = x.size()
 
         x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -406,12 +425,19 @@ class Qwen2_5_VisionTransformer(nn.Module):
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
         position_embeddings = (emb.cos(), emb.sin())
-        # compute cu_seqlens
+        # After building position_embeddings, make sure both cos and sin are on the same device/dtype as the attention input
+        position_embeddings = (
+            position_embeddings[0].to(x.device, x.dtype),
+            position_embeddings[1].to(x.device, x.dtype),
+        )
+
+        # compute cu_seqlens - move cu_seqlens to GPU and make it int32
         cu_seqlens = torch.cat(
             [
-                torch.tensor([0], device=grid_thw.device),
-                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+                torch.tensor([0], device=x.device, dtype=torch.int32),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
+                .cumsum(dim=0)
+                .to(device=x.device, dtype=torch.int32),
             ]
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
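The blocks above flatten [S, B, H] activations to [S*B, H] before each norm call and reshape afterwards. Since RMSNorm normalizes each row over the hidden dimension, the 2D path gives the same values as normalizing the 3D tensor over its last dimension directly. A small self-contained check (pure PyTorch, arbitrary example shapes):

import torch

S, B, H = 4, 2, 8
x = torch.randn(S, B, H)
weight = torch.ones(H)
eps = 1e-6

def rmsnorm(t: torch.Tensor) -> torch.Tensor:
    # Per-row RMS normalization over the last (hidden) dimension.
    return t * torch.rsqrt(t.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

two_d_path = rmsnorm(x.reshape(-1, H)).reshape(S, B, H)  # the diff's flatten-then-reshape path
direct = rmsnorm(x)                                       # normalizing [S, B, H] directly
assert torch.allclose(two_d_path, direct)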