Feature: support qwen and llama4 reducescatter for dp attention padding (#9101)
@@ -253,7 +253,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return lora_output
 
-    def forward(self, input_: torch.Tensor):
+    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
         # duplicate the logic in RowParallelLinear
         if self.base_layer.input_is_parallel:
             input_parallel = input_
@@ -270,7 +270,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)
 
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+        if (
+            self.base_layer.reduce_results
+            and self.base_layer.tp_size > 1
+            and not skip_all_reduce
+        ):
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
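
Context for the change: the new skip_all_reduce flag lets a caller keep the rank-local partial output and perform the tensor-parallel reduction itself, e.g. as a reduce-scatter over padded DP-attention batches instead of a full all-reduce. Below is a minimal sketch, not part of this commit, of how such a caller could combine forward(..., skip_all_reduce=True) with torch.distributed.reduce_scatter_tensor; the layer, tp_group, and shard layout names are assumptions for illustration only.

import torch
import torch.distributed as dist


def row_parallel_forward_with_reduce_scatter(layer, input_: torch.Tensor, tp_group):
    # Skip the layer's internal all-reduce and keep the rank-local partial sum.
    # (Sketch assumes the layer returns a single tensor.)
    partial = layer.forward(input_, skip_all_reduce=True)

    # Reduce-scatter the partial sums: each TP rank ends up with only its shard
    # of the token dimension instead of the fully reduced tensor. DP attention
    # padding is what guarantees the token dimension is divisible by tp_size.
    tp_size = dist.get_world_size(group=tp_group)
    shard = torch.empty(
        (partial.shape[0] // tp_size, *partial.shape[1:]),
        dtype=partial.dtype,
        device=partial.device,
    )
    dist.reduce_scatter_tensor(shard, partial.contiguous(), group=tp_group)
    return shard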