[v0.11.0][refactor] refactor SequenceRowParallelOp forward (#3654)
### What this PR does / why we need it? This PR refactors SequenceRowParallelOp forward. In order to further expand the operator inclusion scope in dynamic judgment scenarios, this PR customizes the entire matmul computation and communication as a custom operator masking. With this refactor, it will support directly writing code such as common operation fusion into the SequenceRowParallelOp class's member function matmul_and_reduce, without the need to register more redundant custom masking operators. ### How was this patch tested? CI passed with new added/existing test. Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -366,14 +366,23 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
else:
|
||||
output_parallel = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
|
||||
output = torch.ops.vllm.matmul_and_reduce(input_parallel,
|
||||
self.prefix)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
def matmul_and_reduce(self, input_parallel: torch.Tensor,
|
||||
bias_: Optional[Parameter]) -> torch.Tensor:
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self.layer,
|
||||
input_parallel,
|
||||
bias=bias_)
|
||||
from vllm_ascend.ops.register_custom_ops import \
|
||||
_maybe_pad_and_reduce_impl
|
||||
output = _maybe_pad_and_reduce_impl(output_parallel)
|
||||
return output
|
||||
|
||||
def update_attrs(self):
|
||||
super().update_attrs()
|
||||
self.input_is_parallel = self.layer.input_is_parallel
|
||||
|
||||
Reference in New Issue
Block a user