[Main][Bugfix] Avoid using the fusion operator in the MoE model (#3834)
### What this PR does / why we need it?
The fused MatmulReduceScatter operator suffers a performance regression in small-shape scenarios, so this PR uses the shape size to decide whether to invoke the fused operator, and disables it for MoE models.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.1

---------

Signed-off-by: ZYang6263 <zy626375@gmail.com>
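For context, here is a minimal, illustrative sketch of the gating this PR introduces (simplified, not the exact vllm-ascend code; `is_moe`, `sp_requested`, `tp_world_size`, and `num_tokens` stand in for the values computed in `set_ascend_forward_context`):

```python
# Illustrative sketch of the gating logic added in this PR (not the exact
# vllm-ascend code). MoE models disable the fused MatmulReduceScatter path,
# because their small shapes make the fused operator slower; dense models keep
# it, and sequence parallelism is only enabled past an empirical threshold of
# 1000 tokens.
def select_mmrs_fusion(is_moe: bool, sp_requested: bool, tp_world_size: int,
                       num_tokens):
    if is_moe:
        sp_enabled = sp_requested and tp_world_size > 1 and num_tokens is not None
        mmrs_fusion = False
    else:
        sp_enabled = (sp_requested and tp_world_size > 1
                      and num_tokens is not None and num_tokens > 1000)
        mmrs_fusion = True
    return sp_enabled, mmrs_fusion
```

The flag is stored on the forward context and later consumed by the row-parallel linear op, as shown in the diff below.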
@@ -113,13 +113,16 @@ def set_ascend_forward_context(
         # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
         # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
         # the performance may degrade due to the switching of communication methods.
+        mmrs_fusion = True
         if is_moe_model(vllm_config):
             sp_enabled = enable_sp(vllm_config) and \
                 tp_world_size > 1 and num_tokens is not None
+            mmrs_fusion = False
         else:
             sp_enabled = enable_sp(vllm_config) and \
                 tp_world_size > 1 and \
                 num_tokens is not None and num_tokens > 1000
+        forward_context.mmrs_fusion = mmrs_fusion

         if sp_enabled:
             pad_size = (tp_world_size -
@@ -382,8 +382,10 @@ class SequenceRowParallelOp(CustomRowParallelOp):
         try:
             forward_context = get_forward_context()
             sp_enabled = forward_context.sp_enabled
+            mmrs_fusion = forward_context.mmrs_fusion
         except AssertionError:
             sp_enabled = False
+            mmrs_fusion = False

         x = input_parallel

@@ -409,7 +411,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
                 quant_per_tensor)

         # For unquant
-        if isinstance(self.layer.quant_method, UnquantizedLinearMethod
-                      ) and torch.version.cann.startswith("8.3"):
+        if mmrs_fusion and isinstance(
+                self.layer.quant_method, UnquantizedLinearMethod
+        ) and torch.version.cann.startswith("8.3"):
             output = torch_npu.npu_mm_reduce_scatter_base(
                 x,
@@ -423,7 +426,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
             if bias_ is not None:
                 output.add_(bias_)
         # For w8a8 quant
-        elif (isinstance(self.layer.quant_method, AscendLinearMethod)
+        elif mmrs_fusion and (
+                isinstance(self.layer.quant_method, AscendLinearMethod)
               and isinstance(self.layer.quant_method.quant_method,
                              AscendW8A8LinearMethod)
               ) and torch.version.cann.startswith("8.3"):
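When `mmrs_fusion` is False, the layer simply skips the fused `npu_mm_reduce_scatter_base` branches above and takes the existing non-fused path. As a rough, hedged illustration of what that fallback amounts to (not the actual vLLM/vllm-ascend implementation, which goes through its own communication helpers), it is a plain matmul followed by an explicit reduce-scatter across the tensor-parallel group:

```python
# Hedged sketch of the unfused fallback selected when mmrs_fusion is False
# (illustrative only; the real layer uses vLLM's tensor-parallel helpers).
import torch
import torch.distributed as dist

def unfused_mm_reduce_scatter(x: torch.Tensor, weight: torch.Tensor,
                              tp_world_size: int) -> torch.Tensor:
    # Per-rank partial result of the row-parallel matmul.
    partial = torch.matmul(x, weight)
    # Sum the partials across ranks and scatter along the token dimension.
    out = torch.empty(partial.shape[0] // tp_world_size, partial.shape[1],
                      dtype=partial.dtype, device=partial.device)
    dist.reduce_scatter_tensor(out, partial)
    return out
```

For small shapes, switching between this path and the fused operator is where the degradation mentioned in the in-code comments can come from, which is why the threshold-based gate is applied in the forward context rather than per layer.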