[main][refactor] refactor SequenceRowParallelOp forward (#3616)
### What this PR does / why we need it? This PR refactors SequenceRowParallelOp forward. In order to further expand the operator inclusion scope in dynamic judgment scenarios, this PR customizes the entire matmul computation and communication as a custom operator masking. With this refactor, it will support directly writing code such as common operation fusion into the `SequenceRowParallelOp` class's member function `matmul_and_reduce`, without the need to register more redundant custom masking operators. ### How was this patch tested? CI passed with existing test. Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -235,6 +235,31 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
|
||||
def _matmul_and_reduce_impl(input_parallel: torch.Tensor,
|
||||
layer_name: str) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
assert self.custom_op is not None
|
||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||
output = self.custom_op.matmul_and_reduce(input_parallel, bias_)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
|
||||
layer_name: str) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
num_tokens = input_parallel.size(0)
|
||||
if forward_context.sp_enabled:
|
||||
num_tokens = num_tokens // self.tp_size
|
||||
output = torch.empty(size=(num_tokens, self.output_size_per_partition),
|
||||
device=input_parallel.device,
|
||||
dtype=input_parallel.dtype)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
|
||||
op_func=_maybe_all_gather_and_maybe_unpad_impl,
|
||||
fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
|
||||
@@ -282,3 +307,9 @@ direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
|
||||
fake_impl=lambda x: x,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
direct_register_custom_op(op_name="matmul_and_reduce",
|
||||
op_func=_matmul_and_reduce_impl,
|
||||
fake_impl=_matmul_and_reduce_impl_fake,
|
||||
mutates_args=[],
|
||||
dispatch_key="PrivateUse1")
|
||||
|
||||
Reference in New Issue
Block a user