perf: optimize qwen-vl with symm mem allreduce (#11381)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Author: Yuan Luo
Date: 2025-10-10 22:24:45 +08:00 (committed by GitHub)
parent a1a20b4c7c
commit 3b9d97f335
5 changed files with 82 additions and 17 deletions


@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -88,10 +89,17 @@ class Qwen2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        should_allreduce_fusion: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=should_allreduce_fusion,
+        )
         return x
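
Why this matters: under tensor parallelism, down_proj is a RowParallelLinear, so each rank computes only a partial sum that is normally all-reduced inside the layer. Passing skip_all_reduce=True leaves the output as per-rank partials so the reduction can be deferred and fused into the next layer's kernels via a symmetric-memory allreduce. A minimal sketch of the deferral mechanism (illustrative only, not SGLang's actual RowParallelLinear):

    import torch
    import torch.distributed as dist

    class RowParallelLinearSketch(torch.nn.Module):
        # Each rank holds one column shard of the weight, so the local
        # matmul produces a partial sum over the sharded input dimension.
        def __init__(self, in_features_per_rank: int, out_features: int):
            super().__init__()
            self.weight = torch.nn.Parameter(
                torch.empty(out_features, in_features_per_rank)
            )

        def forward(self, x: torch.Tensor, skip_all_reduce: bool = False):
            out = x @ self.weight.t()  # partial result on this rank
            if not skip_all_reduce and dist.is_initialized():
                dist.all_reduce(out)  # default path: reduce immediately
            # With skip_all_reduce=True the caller owns the pending
            # reduction and can fuse it into a later kernel.
            return out, None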
@@ -109,9 +117,11 @@ class Qwen2Attention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
+        tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
@@ -143,6 +153,8 @@
             self.total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
             prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
@@ -150,6 +162,8 @@
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
             prefix=add_prefix("o_proj", prefix),
         )
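
The projections now receive tp_rank and tp_size explicitly rather than each linear layer re-querying the process group. These two values drive the head sharding guarded by the assert above; a small worked example of the arithmetic (illustrative numbers, not from this diff):

    # Illustrative shard arithmetic for a tensor-parallel attention layer.
    total_num_heads = 16   # e.g. config.num_attention_heads
    head_dim = 128
    tp_size = 4

    assert total_num_heads % tp_size == 0
    num_local_heads = total_num_heads // tp_size   # 4 heads per rank

    # qkv_proj (column-parallel) emits this rank's slice of Q;
    # o_proj (row-parallel) consumes it and sums partials across ranks.
    local_q_features = num_local_heads * head_dim  # 512 of 2048 columns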
@@ -195,6 +209,7 @@ class Qwen2DecoderLayer(nn.Module):
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -216,6 +231,18 @@
             dual_chunk_attention_config=dual_chunk_attention_config,
             prefix=add_prefix("self_attn", prefix),
         )
+        self.layer_id = layer_id
+
+        self.is_layer_sparse = False
+        is_previous_layer_sparse = False
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
@@ -228,6 +255,14 @@
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
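
In the forward pass (next hunk) the layer asks this communicator whether the MLP's allreduce can be fused with the next layer. The sketch below shows a plausible decision gate; the parameter names and the token threshold are assumptions for illustration, not SGLang's actual logic:

    # Hypothetical sketch of the fusion gate; not SGLang's implementation.
    def should_fuse_mlp_allreduce_with_next_layer_sketch(
        is_last_layer: bool,
        num_tokens: int,
        symm_mem_enabled: bool,
        max_fused_tokens: int = 128,
    ) -> bool:
        if is_last_layer:
            # No next-layer kernel exists to absorb the pending reduction.
            return False
        if not symm_mem_enabled:
            return False
        # One-shot symmetric-memory allreduce mainly wins at small token
        # counts, where launch overhead and ring-allreduce latency dominate.
        return num_tokens <= max_fused_tokens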
@@ -249,7 +284,13 @@
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+        hidden_states = self.mlp(hidden_states, should_allreduce_fusion)
         return hidden_states, residual
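
When should_allreduce_fusion is true, hidden_states leaves this layer as per-rank partial sums, and the next layer completes the reduction fused with its residual add and RMSNorm in a single symmetric-memory kernel. An unfused reference of what that fused step computes (a sketch; the function name is ours):

    import torch
    import torch.distributed as dist

    def allreduce_add_rmsnorm_reference(
        partial: torch.Tensor,    # per-rank partial MLP output
        residual: torch.Tensor,
        weight: torch.Tensor,     # RMSNorm scale
        eps: float = 1e-6,
    ):
        # A fused symmetric-memory kernel performs these three steps in
        # one pass; this reference shows them unfused for clarity.
        dist.all_reduce(partial)       # complete the deferred reduction
        residual = residual + partial  # residual connection
        rms = torch.rsqrt(residual.pow(2).mean(-1, keepdim=True) + eps)
        return residual * rms * weight, residual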