Feature: support qwen and llama4 reducescatter for dp attention padding (#9101)

Author: wxzhoucs
Date: 2025-08-14 12:10:29 +08:00
Committed by: GitHub
Parent: 1bc183c6de
Commit: 4c22897a66
5 changed files with 68 additions and 16 deletions


@@ -253,7 +253,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return lora_output
 
-    def forward(self, input_: torch.Tensor):
+    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
         # duplicate the logic in RowParallelLinear
         if self.base_layer.input_is_parallel:
             input_parallel = input_
@@ -270,7 +270,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+        if (
+            self.base_layer.reduce_results
+            and self.base_layer.tp_size > 1
+            and not skip_all_reduce
+        ):
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
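
The hunk above only shows the layer-side change: when skip_all_reduce=True, the LoRA-wrapped RowParallelLinear returns its per-rank partial output instead of all-reducing it. Below is a minimal, hypothetical caller-side sketch (not code from this PR; the helper name, tp_group argument, and 2-D [tokens, hidden] shape are assumptions) of the pattern the flag enables for DP attention padding: pad the token dimension so it divides evenly across TP ranks, then replace the per-layer all-reduce with a single reduce-scatter so each rank keeps only its own shard of the summed result.

# Hypothetical sketch only; assumes torch.distributed is initialized and
# `layer` is a RowParallelLinearWithLoRA with the new skip_all_reduce flag.
import torch
import torch.distributed as dist

def forward_with_reduce_scatter(layer, hidden_states: torch.Tensor, tp_group):
    # Keep the per-rank partial sum instead of all-reducing inside the layer.
    # (Return handling simplified; the real layer may also return a bias term.)
    partial = layer.forward(hidden_states, skip_all_reduce=True)

    tp_size = dist.get_world_size(group=tp_group)

    # Pad the token dimension so it splits evenly across TP ranks before
    # the reduce-scatter (the "dp attention padding" in the commit title).
    pad = (-partial.shape[0]) % tp_size
    if pad:
        partial = torch.nn.functional.pad(partial, (0, 0, 0, pad))

    # Each rank receives one contiguous shard of the summed output,
    # standing in for the layer's usual tensor_model_parallel_all_reduce.
    shard = partial.new_empty((partial.shape[0] // tp_size, *partial.shape[1:]))
    dist.reduce_scatter_tensor(shard, partial, op=dist.ReduceOp.SUM, group=tp_group)
    return shard

The point of the flag is that a reduce-scatter moves only 1/tp_size of the data to each rank, so when downstream DP attention code already works on per-rank token shards, skipping the full all-reduce avoids communicating activations every rank would immediately re-slice anyway.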