perf: optimize qwen-vl with symm mem allreduce (#11381)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
@@ -88,10 +89,17 @@ class Qwen2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        should_allreduce_fusion: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=should_allreduce_fusion,
+        )
         return x
 
 
@@ -109,9 +117,11 @@ class Qwen2Attention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
|
||||
prefix: str = "",
|
||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
@@ -143,6 +153,8 @@ class Qwen2Attention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
tp_rank=tp_rank,
|
||||
tp_size=tp_size,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
@@ -150,6 +162,8 @@ class Qwen2Attention(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
tp_rank=tp_rank,
|
||||
tp_size=tp_size,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
@@ -195,6 +209,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
@@ -216,6 +231,18 @@ class Qwen2DecoderLayer(nn.Module):
             dual_chunk_attention_config=dual_chunk_attention_config,
             prefix=add_prefix("self_attn", prefix),
         )
+
+        self.layer_id = layer_id
+        self.is_layer_sparse = False
+        is_previous_layer_sparse = False
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
@@ -228,6 +255,14 @@ class Qwen2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -249,7 +284,13 @@ class Qwen2DecoderLayer(nn.Module):
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
+        hidden_states = self.mlp(hidden_states, should_allreduce_fusion)
         return hidden_states, residual
 
 
|
||||
Reference in New Issue
Block a user