Refactor duplicated code into a common method to reduce redundancy (#7210)
### What this PR does / why we need it?
1. Extract duplicated code into a helper method:
define `get_input_parallel` in the parent class
`CustomRowParallelOp`, and call it from the `apply_impl` method of each
of its 5 child classes:
- MLPRowParallelOp
- OProjRowParallelOp
- Flashcomm2OProjRowParallelOp
- MatmulAllreduceRowParallelOp
- SequenceRowParallelOp
2. Fix a variable-name typo: rename `splitted_input` to `split_input` (the past tense of "split" is "split").
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
Signed-off-by: idouba <zhangchaomeng@huawei.com>
This commit is contained in:
@@ -157,6 +157,13 @@ class CustomRowParallelOp(CustomLinearOp):
|
|||||||
return output
|
return output
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
|
|
||||||
|
def get_input_parallel(self, input_: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.input_is_parallel:
|
||||||
|
return input_
|
||||||
|
|
||||||
|
split_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
|
||||||
|
return split_input[self.tp_rank].contiguous()
|
||||||
|
|
||||||
|
|
||||||
class CustomReplicatedOp(CustomLinearOp):
|
class CustomReplicatedOp(CustomLinearOp):
|
||||||
def apply_impl(self, input_):
|
def apply_impl(self, input_):
|
||||||
@@ -200,11 +207,7 @@ class MLPRowParallelOp(CustomRowParallelOp):
|
|||||||
return get_mlp_tp_group()
|
return get_mlp_tp_group()
|
||||||
|
|
||||||
def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||||
if self.input_is_parallel:
|
input_parallel = self.get_input_parallel(input_)
|
||||||
input_parallel = input_
|
|
||||||
else:
|
|
||||||
splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
|
|
||||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
|
||||||
|
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.layer.bias
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.layer.bias
|
||||||
@@ -227,11 +230,7 @@ class OProjRowParallelOp(CustomRowParallelOp):
|
|||||||
self,
|
self,
|
||||||
input_: torch.Tensor,
|
input_: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||||
if self.input_is_parallel:
|
input_parallel = self.get_input_parallel(input_)
|
||||||
input_parallel = input_
|
|
||||||
else:
|
|
||||||
splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
|
|
||||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
|
||||||
|
|
||||||
# Prepare tensors for all-to-all communication
|
# Prepare tensors for all-to-all communication
|
||||||
local_batch_size = input_parallel.size(0)
|
local_batch_size = input_parallel.size(0)
|
||||||
@@ -303,12 +302,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
|
|||||||
Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]
|
Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize]
|
||||||
"""
|
"""
|
||||||
# Handle input parallelism - split or use as-is
|
# Handle input parallelism - split or use as-is
|
||||||
if self.input_is_parallel:
|
input_parallel = self.get_input_parallel(input_)
|
||||||
input_parallel = input_
|
|
||||||
else:
|
|
||||||
tp_rank = self.tp_rank
|
|
||||||
splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
|
|
||||||
input_parallel = splitted_input[tp_rank].contiguous()
|
|
||||||
|
|
||||||
# padding for all-to-all
|
# padding for all-to-all
|
||||||
num_padding_tokens = _EXTRA_CTX.pad_size
|
num_padding_tokens = _EXTRA_CTX.pad_size
|
||||||
@@ -394,11 +388,7 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
|
|||||||
self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)
|
self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)
|
||||||
|
|
||||||
def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||||
if self.input_is_parallel:
|
input_parallel = self.get_input_parallel(input_)
|
||||||
input_parallel = input_
|
|
||||||
else:
|
|
||||||
splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
|
|
||||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
|
||||||
"""Calculate the output tensor of forward by considering
|
"""Calculate the output tensor of forward by considering
|
||||||
fusing communication and computation."""
|
fusing communication and computation."""
|
||||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||||
@@ -492,12 +482,7 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
|||||||
Implemented multiple optimization projects for dense models, such as FlashComm and
|
Implemented multiple optimization projects for dense models, such as FlashComm and
|
||||||
communication-computation fusion.
|
communication-computation fusion.
|
||||||
"""
|
"""
|
||||||
|
input_parallel = self.get_input_parallel(input_)
|
||||||
if self.input_is_parallel:
|
|
||||||
input_parallel = input_
|
|
||||||
else:
|
|
||||||
splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
|
|
||||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
|
||||||
|
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||||
|
|||||||
Reference in New Issue
Block a user