[v0.11.0][Bugfix]Avoid using the fusion operator in the MOE model (#3837)
### What this PR does / why we need it? The current MatmulReduceScatter operator experiences performance degradation in small-shape scenarios, so it determines whether to use this operator by judging the size of the shape. --------- Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
@@ -382,8 +382,10 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
try:
|
||||
forward_context = get_forward_context()
|
||||
sp_enabled = forward_context.sp_enabled
|
||||
mmrs_fusion = forward_context.mmrs_fusion
|
||||
except AssertionError:
|
||||
sp_enabled = False
|
||||
mmrs_fusion = False
|
||||
|
||||
x = input_parallel
|
||||
|
||||
@@ -409,8 +411,9 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
quant_per_tensor)
|
||||
|
||||
# For unquant
|
||||
if isinstance(self.layer.quant_method, UnquantizedLinearMethod
|
||||
) and torch.version.cann.startswith("8.3"):
|
||||
if mmrs_fusion and isinstance(
|
||||
self.layer.quant_method, UnquantizedLinearMethod
|
||||
) and torch.version.cann.startswith("8.3"):
|
||||
output = torch_npu.npu_mm_reduce_scatter_base(
|
||||
x,
|
||||
self.layer.weight.t(),
|
||||
@@ -423,10 +426,11 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
if bias_ is not None:
|
||||
output.add_(bias_)
|
||||
# For w8a8 quant
|
||||
elif (isinstance(self.layer.quant_method, AscendLinearMethod)
|
||||
and isinstance(self.layer.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod)
|
||||
) and torch.version.cann.startswith("8.3"):
|
||||
elif mmrs_fusion and (
|
||||
isinstance(self.layer.quant_method, AscendLinearMethod)
|
||||
and isinstance(self.layer.quant_method.quant_method,
|
||||
AscendW8A8LinearMethod)
|
||||
) and torch.version.cann.startswith("8.3"):
|
||||
if x.dtype != torch.int8:
|
||||
x_quant = quant_per_tensor(
|
||||
x, self.layer.aclnn_input_scale_reciprocal,
|
||||
|
||||
Reference in New Issue
Block a user