[BugFix]Disable dispatch_gmm_combine_decode operator when mtp drafter model uses non-w8a8 while main model uses w8a8, or drafter model is eagle series (#5293)
…w8a8 while main model uses w8a8
### What this PR does / why we need it?
Disable dispatch_gmm_combine_decode operator when mtp drafter model uses
non-w8a8 while main model uses w8a8, or drafter model is eagle series.
For more information about this operator, please refer to the RFC issue:
https://github.com/vllm-project/vllm-ascend/issues/5476
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
@@ -14,7 +14,8 @@ import vllm_ascend.envs as envs_ascend
|
|||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
|
from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
|
||||||
get_ascend_device_type, has_layer_idx,
|
get_ascend_device_type, has_layer_idx,
|
||||||
is_moe_model)
|
is_moe_model,
|
||||||
|
speculative_enable_dispatch_gmm_combine_decode)
|
||||||
|
|
||||||
|
|
||||||
class MoECommType(Enum):
|
class MoECommType(Enum):
|
||||||
@@ -242,7 +243,7 @@ def select_moe_comm_method(num_tokens: int,
|
|||||||
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||||
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
||||||
# TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs
|
# TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs
|
||||||
# TODO: add guard for dispatch_gmm_combine_decode when mtp uses float while moe uses w8a8
|
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
|
||||||
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and (
|
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and (
|
||||||
not dynamic_eplb)
|
not dynamic_eplb)
|
||||||
if num_tokens <= mc2_tokens_capacity:
|
if num_tokens <= mc2_tokens_capacity:
|
||||||
@@ -250,6 +251,9 @@ def select_moe_comm_method(num_tokens: int,
|
|||||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
fused_decode_enable = fused_mc2_enable and get_ep_group(
|
fused_decode_enable = fused_mc2_enable and get_ep_group(
|
||||||
).world_size <= 16 and (not is_draft_model)
|
).world_size <= 16 and (not is_draft_model)
|
||||||
|
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||||
|
fused_decode_enable = fused_mc2_enable and \
|
||||||
|
speculative_enable_dispatch_gmm_combine_decode(vllm_config)
|
||||||
moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
|
moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
|
||||||
else:
|
else:
|
||||||
fused_prefill_enable = fused_mc2_enable
|
fused_prefill_enable = fused_mc2_enable
|
||||||
|
|||||||
@@ -831,6 +831,23 @@ def is_moe_model(vllm_config: VllmConfig):
|
|||||||
return _IS_MOE_MODEL
|
return _IS_MOE_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
def speculative_enable_dispatch_gmm_combine_decode(
        vllm_config: VllmConfig) -> bool:
    """Decide whether the dispatch_gmm_combine_decode operator may stay on
    under the current speculative-decoding configuration.

    Per the PR description, the fused operator must be disabled when the
    drafter is an eagle-series model, or when an mtp drafter is not
    quantized to w8a8 while the main model is.
    """
    spec_config = vllm_config.speculative_config
    # No speculative decoding configured: nothing restricts the operator.
    if spec_config is None:
        return True

    method = getattr(spec_config, "method", None)

    # Drafter-free speculative methods place no constraint either.
    if method in (None, "ngram", "suffix"):
        return True

    # Eagle-series drafters are not supported by the fused operator.
    if method in ("eagle", "eagle3"):
        return False

    if method == "mtp":
        # Only an mtp drafter quantized as w8a8_dynamic is compatible.
        mtp_quant_type = getattr(vllm_config.model_config.hf_config,
                                 "mtp_quantize", None)
        return mtp_quant_type == "w8a8_dynamic"

    # Unknown speculative method: conservatively disable the operator.
    return False
|
||||||
|
|
||||||
|
|
||||||
def _is_contain_expert(config: Any):
|
def _is_contain_expert(config: Any):
|
||||||
if isinstance(config, dict):
|
if isinstance(config, dict):
|
||||||
for k, v in config.items():
|
for k, v in config.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user