[Feature] Use DispatchGmmCombineDecode operator to replace MC2 (Optional) (#5040)
### What this PR does / why we need it?
This PR adds model-side integration for the previously introduced
experimental AscendC fused operator DispatchGmmCombineDecode, used in
MoE decoding.
The operator implementation itself was added in a prior PR,
[#4139](https://github.com/vllm-project/vllm-ascend/pull/4139).
This change only adapts the model execution path to optionally use the
fused operator.
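When the environment variable VLLM_ASCEND_ENABLE_FUSED_MC2=2 is set, the
original MC2 path composed of multiple operators (A8W8 dispatch → GMM →
SwiGLU → GMM → combine) may be replaced by the single fused operator
DispatchGmmCombineDecode. In the selection logic changed here, the
replacement applies only when the MoE quantization type is W8A8 dynamic,
dynamic EPLB is not enabled, and the number of tokens fits within the MC2
token capacity; otherwise the existing MC2 or all-to-all path is used.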
By default, the existing multi-operator MC2 implementation is preserved.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
@@ -253,12 +253,24 @@ def select_moe_comm_method(num_tokens: int,
|
|||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||||
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
||||||
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group(
|
# TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs
|
||||||
).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model)
|
# TODO: add guard for dispatch_gmm_combine_decode when mtp uses float while moe uses w8a8
|
||||||
|
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and (
|
||||||
|
not dynamic_eplb)
|
||||||
if num_tokens <= mc2_tokens_capacity:
|
if num_tokens <= mc2_tokens_capacity:
|
||||||
moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2
|
fused_decode_enable = fused_mc2_enable
|
||||||
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
|
fused_decode_enable = fused_mc2_enable and get_ep_group(
|
||||||
|
).world_size <= 16 and (not is_mtp_model)
|
||||||
|
moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
|
||||||
else:
|
else:
|
||||||
moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL
|
fused_prefill_enable = fused_mc2_enable
|
||||||
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
|
fused_prefill_enable = fused_mc2_enable and get_ep_group(
|
||||||
|
).world_size <= 16 and (not is_mtp_model)
|
||||||
|
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||||
|
fused_prefill_enable = False
|
||||||
|
moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||||
|
|||||||
@@ -135,7 +135,13 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# Whether to anbale dynamic EPLB
|
# Whether to anbale dynamic EPLB
|
||||||
"DYNAMIC_EPLB":
|
"DYNAMIC_EPLB":
|
||||||
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
|
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
|
||||||
# Whether to anbale fused mc2(dispatch_gmm_combine_decode/dispatch_ffn_combine operator)
|
# Whether to enable fused mc2(`dispatch_gmm_combine_decode`/`dispatch_ffn_combine` operator)
|
||||||
|
# 0, or not set: default ALLTOALL and MC2 will be used.
|
||||||
|
# 1: ALLTOALL and MC2 might be replaced by `dispatch_ffn_combine` operator.
|
||||||
|
# `dispatch_ffn_combine` can be used only for moe layer with W8A8, EP<=16, non-mtp, non-dynamic-eplb.
|
||||||
|
# 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator.
|
||||||
|
# `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer
|
||||||
|
# with W8A8, non-dynamic-eplb. And MTP layer must be W8A8.
|
||||||
"VLLM_ASCEND_ENABLE_FUSED_MC2":
|
"VLLM_ASCEND_ENABLE_FUSED_MC2":
|
||||||
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
|
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -345,7 +345,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
shared_out = fc3_context.shared_experts(hidden_states)
|
shared_out = fc3_context.shared_experts(hidden_states)
|
||||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||||
moe_comm_type = forward_context.moe_comm_type
|
moe_comm_type = forward_context.moe_comm_type
|
||||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
|
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
|
||||||
and not shared_expert_dp_enabled():
|
and not shared_expert_dp_enabled():
|
||||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||||
set_flash_common3_context(shared_out=shared_out)
|
set_flash_common3_context(shared_out=shared_out)
|
||||||
|
|||||||
@@ -291,9 +291,9 @@ class FusedMC2CommImpl(MoECommMethod):
|
|||||||
assert not (
|
assert not (
|
||||||
w1_scale is None or w2_scale is None
|
w1_scale is None or w2_scale is None
|
||||||
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||||
out = torch.empty_like(hidden_states)
|
|
||||||
|
|
||||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
|
out = torch.empty_like(hidden_states)
|
||||||
torch.ops._C_ascend.dispatch_ffn_combine(
|
torch.ops._C_ascend.dispatch_ffn_combine(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
weight1=w1[0],
|
weight1=w1[0],
|
||||||
@@ -307,7 +307,21 @@ class FusedMC2CommImpl(MoECommMethod):
|
|||||||
out=out,
|
out=out,
|
||||||
)
|
)
|
||||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||||
raise NotImplementedError()
|
assert expert_map is not None, "expert_map cannot be None."
|
||||||
|
out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
|
||||||
|
x=hidden_states,
|
||||||
|
expert_ids=topk_ids,
|
||||||
|
gmm1_permuted_weight=w1[0],
|
||||||
|
gmm1_permuted_weight_scale=w1_scale[0],
|
||||||
|
gmm2_weight=w2[0],
|
||||||
|
gmm2_weight_scale=w2_scale[0],
|
||||||
|
expert_smooth_scales=None,
|
||||||
|
expert_scales=topk_weights.to(torch.float32),
|
||||||
|
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
|
||||||
|
ep_rank_size=self.token_dispatcher.ep_world_size,
|
||||||
|
ep_rank_id=self.token_dispatcher.ep_rank_id,
|
||||||
|
moe_expert_num=len(expert_map),
|
||||||
|
global_bs=self.token_dispatcher.fused_global_bs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||||
|
|||||||
@@ -125,6 +125,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
|
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
|
||||||
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
|
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
|
||||||
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
|
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
|
||||||
|
self.fused_global_bs = max_num_tokens * self.ep_world_size
|
||||||
|
|
||||||
def get_dispatch_mc2_kwargs(
|
def get_dispatch_mc2_kwargs(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -231,6 +231,10 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
topk_weights = topk_weights.to(self.in_dtype)
|
topk_weights = topk_weights.to(self.in_dtype)
|
||||||
|
|
||||||
moe_comm_method = get_forward_context().moe_comm_method
|
moe_comm_method = get_forward_context().moe_comm_method
|
||||||
|
# When VLLM_ASCEND_ENABLE_FUSED_MC2 == 2, use dispatch_gmm_combine_decode, need fp32 scale
|
||||||
|
w2_weight_scale_fp32_flag = (
|
||||||
|
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
|
||||||
|
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2)
|
||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
w1 = layer.w13_weight_list
|
w1 = layer.w13_weight_list
|
||||||
w1_scale = layer.w13_weight_scale_fp32_list
|
w1_scale = layer.w13_weight_scale_fp32_list
|
||||||
@@ -240,7 +244,10 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
w1 = [layer.w13_weight]
|
w1 = [layer.w13_weight]
|
||||||
w1_scale = [layer.w13_weight_scale_fp32]
|
w1_scale = [layer.w13_weight_scale_fp32]
|
||||||
w2 = [layer.w2_weight]
|
w2 = [layer.w2_weight]
|
||||||
w2_scale = [layer.w2_weight_scale]
|
w2_scale = [
|
||||||
|
layer.w2_weight_scale_fp32
|
||||||
|
if w2_weight_scale_fp32_flag else layer.w2_weight_scale
|
||||||
|
]
|
||||||
|
|
||||||
fused_scale_flag = (get_forward_context().moe_comm_type
|
fused_scale_flag = (get_forward_context().moe_comm_type
|
||||||
== MoECommType.FUSED_MC2
|
== MoECommType.FUSED_MC2
|
||||||
@@ -279,6 +286,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
layer.w13_weight_offset.data.shape[0], -1)
|
layer.w13_weight_offset.data.shape[0], -1)
|
||||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
||||||
layer.w2_weight_scale.data.shape[0], -1)
|
layer.w2_weight_scale.data.shape[0], -1)
|
||||||
|
layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to(
|
||||||
|
torch.float32)
|
||||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||||
layer.w2_weight_offset.data.shape[0], -1)
|
layer.w2_weight_offset.data.shape[0], -1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user