[v0.11.0][refactor] refactor SequenceRowParallelOp forward (#3654)
### What this PR does / why we need it? This PR refactors SequenceRowParallelOp forward. In order to further expand the operator inclusion scope in dynamic judgment scenarios, this PR customizes the entire matmul computation and communication as a custom operator masking. With this refactor, it will support directly writing code such as common operation fusion into the SequenceRowParallelOp class's member function matmul_and_reduce, without the need to register more redundant custom masking operators. ### How was this patch tested? CI passed with new added/existing test. Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -26,6 +26,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import divide
|
||||
from vllm.model_executor.layers.linear import ( # noqa
|
||||
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
|
||||
@@ -234,6 +235,13 @@ class AscendRowParallelLinear(RowParallelLinear):
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
# TODO(shaopeng-666): Remove the visual check after the mm model reconstruction is complete.
|
||||
if prefix in compilation_config.static_forward_context and \
|
||||
"visual" not in prefix:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
|
||||
disable_tp, prefix, self, "row")
|
||||
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
||||
|
||||
Reference in New Issue
Block a user