Feature: support qwen and llama4 reducescatter for dp attention padding (#9101)
This commit is contained in:
@@ -253,7 +253,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
)
|
)
|
||||||
return lora_output
|
return lora_output
|
||||||
|
|
||||||
def forward(self, input_: torch.Tensor):
|
def forward(self, input_: torch.Tensor, skip_all_reduce=False):
|
||||||
# duplicate the logic in RowParallelLinear
|
# duplicate the logic in RowParallelLinear
|
||||||
if self.base_layer.input_is_parallel:
|
if self.base_layer.input_is_parallel:
|
||||||
input_parallel = input_
|
input_parallel = input_
|
||||||
@@ -270,7 +270,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
if self.set_lora:
|
if self.set_lora:
|
||||||
output_parallel = self.apply_lora(output_parallel, input_parallel)
|
output_parallel = self.apply_lora(output_parallel, input_parallel)
|
||||||
|
|
||||||
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
if (
|
||||||
|
self.base_layer.reduce_results
|
||||||
|
and self.base_layer.tp_size > 1
|
||||||
|
and not skip_all_reduce
|
||||||
|
):
|
||||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
else:
|
else:
|
||||||
output_ = output_parallel
|
output_ = output_parallel
|
||||||
|
|||||||
@@ -91,10 +91,18 @@ class LlamaMLP(nn.Module):
|
|||||||
)
|
)
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
def forward(self, x, forward_batch=None):
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
forward_batch=None,
|
||||||
|
use_reduce_scatter: bool = False,
|
||||||
|
):
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
x = self.act_fn(gate_up)
|
x = self.act_fn(gate_up)
|
||||||
x, _ = self.down_proj(x)
|
x, _ = self.down_proj(
|
||||||
|
x,
|
||||||
|
skip_all_reduce=use_reduce_scatter,
|
||||||
|
)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -131,14 +131,19 @@ class Llama4MoE(nn.Module):
|
|||||||
reduce_results=False, # We need to do scatter before reduce
|
reduce_results=False, # We need to do scatter before reduce
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states, forward_batch: ForwardBatch):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
use_reduce_scatter: bool = False,
|
||||||
|
):
|
||||||
shared_out, routed_out = self._forward_core(
|
shared_out, routed_out = self._forward_core(
|
||||||
hidden_states, forward_batch.forward_mode
|
hidden_states, forward_batch.forward_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
out_aD = routed_out + shared_out
|
out_aD = routed_out + shared_out
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1 and not use_reduce_scatter:
|
||||||
out_aD = tensor_model_parallel_all_reduce(out_aD)
|
out_aD = tensor_model_parallel_all_reduce(out_aD)
|
||||||
|
|
||||||
return out_aD
|
return out_aD
|
||||||
@@ -412,6 +417,7 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
layer_scatter_modes=self.layer_scatter_modes,
|
layer_scatter_modes=self.layer_scatter_modes,
|
||||||
input_layernorm=self.input_layernorm,
|
input_layernorm=self.input_layernorm,
|
||||||
post_attention_layernorm=self.post_attention_layernorm,
|
post_attention_layernorm=self.post_attention_layernorm,
|
||||||
|
allow_reduce_scatter=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _is_moe_layer(self, layer_id: int) -> bool:
|
def _is_moe_layer(self, layer_id: int) -> bool:
|
||||||
@@ -441,8 +447,15 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
hidden_states, residual, forward_batch
|
hidden_states, residual, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For DP with padding, reduce scatter can be used instead of all-reduce.
|
||||||
|
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
|
||||||
|
forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states = self.feed_forward(hidden_states, forward_batch)
|
hidden_states = self.feed_forward(
|
||||||
|
hidden_states, forward_batch, use_reduce_scatter
|
||||||
|
)
|
||||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||||
hidden_states, residual, forward_batch
|
hidden_states, residual, forward_batch
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -108,10 +108,14 @@ class Qwen2MoeMLP(nn.Module):
|
|||||||
)
|
)
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
use_reduce_scatter: bool = False,
|
||||||
|
):
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
x = self.act_fn(gate_up)
|
x = self.act_fn(gate_up)
|
||||||
x, _ = self.down_proj(x)
|
x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -176,7 +180,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
|
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
forward_batch: Optional[ForwardBatch] = None,
|
||||||
|
use_reduce_scatter: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
@@ -194,6 +201,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
|
if self.tp_size > 1 and not use_reduce_scatter:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||||
@@ -368,6 +376,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
layer_scatter_modes=self.layer_scatter_modes,
|
layer_scatter_modes=self.layer_scatter_modes,
|
||||||
input_layernorm=self.input_layernorm,
|
input_layernorm=self.input_layernorm,
|
||||||
post_attention_layernorm=self.post_attention_layernorm,
|
post_attention_layernorm=self.post_attention_layernorm,
|
||||||
|
allow_reduce_scatter=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -393,7 +402,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
hidden_states, residual, forward_batch
|
hidden_states, residual, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states, forward_batch)
|
# For DP with padding, reduce scatter can be used instead of all-reduce.
|
||||||
|
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
|
||||||
|
forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
|
||||||
|
|
||||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||||
hidden_states, residual, forward_batch
|
hidden_states, residual, forward_batch
|
||||||
|
|||||||
@@ -144,11 +144,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
forward_batch: Optional[ForwardBatch] = None,
|
||||||
|
use_reduce_scatter: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
|
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||||
return self.forward_normal(hidden_states)
|
return self.forward_normal(hidden_states, use_reduce_scatter)
|
||||||
else:
|
else:
|
||||||
return self.forward_deepep(hidden_states, forward_batch)
|
return self.forward_deepep(hidden_states, forward_batch)
|
||||||
|
|
||||||
@@ -159,7 +162,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
if name not in ["correction_bias"]
|
if name not in ["correction_bias"]
|
||||||
]
|
]
|
||||||
|
|
||||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward_normal(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
use_reduce_scatter: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
@@ -167,7 +174,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
topk_output = self.topk(hidden_states, router_logits)
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1 and not use_reduce_scatter:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||||
@@ -521,6 +528,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
|||||||
layer_scatter_modes=self.layer_scatter_modes,
|
layer_scatter_modes=self.layer_scatter_modes,
|
||||||
input_layernorm=self.input_layernorm,
|
input_layernorm=self.input_layernorm,
|
||||||
post_attention_layernorm=self.post_attention_layernorm,
|
post_attention_layernorm=self.post_attention_layernorm,
|
||||||
|
allow_reduce_scatter=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -546,7 +554,12 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
|||||||
hidden_states, residual, forward_batch
|
hidden_states, residual, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states, forward_batch)
|
# For DP with padding, reduce scatter can be used instead of all-reduce.
|
||||||
|
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
|
||||||
|
forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
|
||||||
|
|
||||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||||
hidden_states, residual, forward_batch
|
hidden_states, residual, forward_batch
|
||||||
|
|||||||
Reference in New Issue
Block a user