[EPLB][bugfix] Bugfix for fused mc2 (#6794)
### What this PR does / why we need it?
This pull request addresses a bug related to the fused mc2 functionality
within the EPLB (Expert Parallelism Load Balancing) system, specifically
impacting quantization and MoE communication.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main:
83b47f67b1
Signed-off-by: Spicy-Stick <873805887@qq.com>
Signed-off-by: root <root@localhost.localdomain>
This commit is contained in:
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
|
||||
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
|
||||
from vllm_ascend.quantization.methods.base import QuantType
|
||||
from transformers import DeepseekV2Config
|
||||
|
||||
|
||||
@@ -17,6 +18,8 @@ class TestVllmAdaptor(unittest.TestCase):
|
||||
mock_model.get_expert_map.return_value = [i for i in range(n_routed_experts)]
|
||||
mock_model.get_log2phy_map.return_value = [i for i in range(n_routed_experts)]
|
||||
self.model = mock_model
|
||||
num_dense_layers = getattr(config, "first_k_dense_replace", 0)
|
||||
self.model.model.layers[num_dense_layers].mlp.experts.quant_type = QuantType.W8A8
|
||||
|
||||
self.mock_rank = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_rank", return_value=0).start()
|
||||
self.mock_size = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_world_size", return_value=4).start()
|
||||
|
||||
@@ -9,7 +9,6 @@ from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parall
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import (
|
||||
AscendDeviceType,
|
||||
enable_sp,
|
||||
@@ -243,11 +242,10 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
|
||||
elif soc_version in {AscendDeviceType.A3}:
|
||||
dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
|
||||
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
||||
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
|
||||
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
|
||||
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb)
|
||||
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model)
|
||||
if num_tokens <= mc2_tokens_capacity:
|
||||
fused_decode_enable = fused_mc2_enable
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
|
||||
@@ -22,6 +22,9 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.quantization.methods.base import QuantType
|
||||
|
||||
|
||||
class VllmEplbAdaptor:
|
||||
def __init__(self, model, **args):
|
||||
@@ -59,12 +62,19 @@ class VllmEplbAdaptor:
|
||||
def init_expert_param_per_layer(self):
|
||||
self.param_dict = dict()
|
||||
if self.model.quant_config is not None:
|
||||
quant_type = self.model.model.layers[self.num_dense_layers].mlp.experts.quant_type
|
||||
if quant_type == QuantType.W8A8:
|
||||
self.expert_weight_names = [
|
||||
"w13_weight_list",
|
||||
"w2_weight_list",
|
||||
"w13_weight_scale_fp32_list",
|
||||
"w2_weight_scale_list",
|
||||
]
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
self.expert_weight_names.append("fused_w1_scale_list")
|
||||
self.expert_weight_names.append("fused_w2_scale_list")
|
||||
else:
|
||||
raise ValueError(f"EPLB not support {quant_type}")
|
||||
else:
|
||||
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class FusedExpertsResult:
|
||||
before_dispatch_evt: torch.npu.Event | None = None
|
||||
before_combine_evt: torch.npu.Event | None = None
|
||||
# For dynamic_eplb
|
||||
group_list_type: int | None = None
|
||||
group_list_type: int = 1
|
||||
expert_tokens: torch.Tensor | None = None
|
||||
|
||||
|
||||
@@ -355,7 +355,6 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
|
||||
group_list_type = None
|
||||
expert_tokens = None
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
out = torch.empty_like(hidden_states)
|
||||
@@ -375,7 +374,6 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
expert_tokens = self.expert_token_nums
|
||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||
assert expert_map is not None, "expert_map cannot be None."
|
||||
group_list_type = 1
|
||||
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
|
||||
x=hidden_states,
|
||||
expert_ids=topk_ids,
|
||||
@@ -393,4 +391,4 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||
return FusedExpertsResult(routed_out=out, group_list_type=group_list_type, expert_tokens=expert_tokens)
|
||||
return FusedExpertsResult(routed_out=out, expert_tokens=expert_tokens)
|
||||
|
||||
@@ -235,28 +235,28 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
topk_weights = topk_weights.to(self.in_dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
if self.dynamic_eplb:
|
||||
w1 = layer.w13_weight_list
|
||||
w1_scale = layer.w13_weight_scale_fp32_list
|
||||
w2 = layer.w2_weight_list
|
||||
w2_scale = layer.w2_weight_scale_list
|
||||
else:
|
||||
w1 = [layer.w13_weight]
|
||||
w1_scale = [layer.w13_weight_scale_fp32]
|
||||
w2 = [layer.w2_weight]
|
||||
w2_scale = [layer.w2_weight_scale]
|
||||
|
||||
fused_scale_flag = (
|
||||
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
|
||||
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
|
||||
)
|
||||
if self.dynamic_eplb:
|
||||
w1 = layer.w13_weight_list
|
||||
w1_scale = layer.fused_w1_scale_list if fused_scale_flag else layer.w13_weight_scale_fp32_list
|
||||
w2 = layer.w2_weight_list
|
||||
w2_scale = layer.fused_w2_scale_list if fused_scale_flag else layer.w2_weight_scale_list
|
||||
else:
|
||||
w1 = [layer.w13_weight]
|
||||
w1_scale = [layer.fused_w1_scale] if fused_scale_flag else [layer.w13_weight_scale_fp32]
|
||||
w2 = [layer.w2_weight]
|
||||
w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
|
||||
|
||||
final_hidden_states = moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
pertoken_scale=pertoken_scale,
|
||||
w1=w1,
|
||||
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
|
||||
w2_scale=w2_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int8_w8a8=True,
|
||||
@@ -282,6 +282,7 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
|
||||
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
|
||||
layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
|
||||
|
||||
@@ -292,9 +293,21 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
|
||||
]
|
||||
layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)]
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
layer.fused_w1_scale_list = [
|
||||
weight.clone()
|
||||
for weight in layer.fused_w1_scale.view(len(layer.w13_weight_list), -1).data.unbind(dim=0)
|
||||
]
|
||||
layer.fused_w2_scale_list = [
|
||||
weight.clone()
|
||||
for weight in layer.fused_w2_scale.view(len(layer.w2_weight_list), -1).data.unbind(dim=0)
|
||||
]
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
del layer.w13_weight_scale
|
||||
del layer.w13_weight_scale_fp32
|
||||
del layer.w2_weight_scale
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
del layer.fused_w1_scale
|
||||
del layer.fused_w2_scale
|
||||
torch.npu.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user