[EPLB][bugfix] Bugfix for fused mc2 (#6794)
### What this PR does / why we need it?
This pull request fixes a bug that occurs when the fused MC2 path is used
together with EPLB (Expert Parallelism Load Balancing). The fix touches W8A8
dynamic quantization, the EPLB adaptor's expert weight list, and MoE
communication-method selection.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main: 83b47f67b1
Signed-off-by: Spicy-Stick <873805887@qq.com>
Signed-off-by: root <root@localhost.localdomain>
This commit is contained in:
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
|
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
|
||||||
|
from vllm_ascend.quantization.methods.base import QuantType
|
||||||
from transformers import DeepseekV2Config
|
from transformers import DeepseekV2Config
|
||||||
|
|
||||||
|
|
||||||
@@ -17,6 +18,8 @@ class TestVllmAdaptor(unittest.TestCase):
|
|||||||
mock_model.get_expert_map.return_value = [i for i in range(n_routed_experts)]
|
mock_model.get_expert_map.return_value = [i for i in range(n_routed_experts)]
|
||||||
mock_model.get_log2phy_map.return_value = [i for i in range(n_routed_experts)]
|
mock_model.get_log2phy_map.return_value = [i for i in range(n_routed_experts)]
|
||||||
self.model = mock_model
|
self.model = mock_model
|
||||||
|
num_dense_layers = getattr(config, "first_k_dense_replace", 0)
|
||||||
|
self.model.model.layers[num_dense_layers].mlp.experts.quant_type = QuantType.W8A8
|
||||||
|
|
||||||
self.mock_rank = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_rank", return_value=0).start()
|
self.mock_rank = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_rank", return_value=0).start()
|
||||||
self.mock_size = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_world_size", return_value=4).start()
|
self.mock_size = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_world_size", return_value=4).start()
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parall
|
|||||||
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
|
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
|
||||||
from vllm_ascend.utils import (
|
from vllm_ascend.utils import (
|
||||||
AscendDeviceType,
|
AscendDeviceType,
|
||||||
enable_sp,
|
enable_sp,
|
||||||
@@ -243,11 +242,10 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
|
|||||||
moe_comm_type = MoECommType.ALLGATHER
|
moe_comm_type = MoECommType.ALLGATHER
|
||||||
|
|
||||||
elif soc_version in {AscendDeviceType.A3}:
|
elif soc_version in {AscendDeviceType.A3}:
|
||||||
dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
|
|
||||||
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
||||||
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
|
# TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
|
||||||
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
|
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
|
||||||
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb)
|
dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model)
|
||||||
if num_tokens <= mc2_tokens_capacity:
|
if num_tokens <= mc2_tokens_capacity:
|
||||||
fused_decode_enable = fused_mc2_enable
|
fused_decode_enable = fused_mc2_enable
|
||||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
|
from vllm_ascend.quantization.methods.base import QuantType
|
||||||
|
|
||||||
|
|
||||||
class VllmEplbAdaptor:
|
class VllmEplbAdaptor:
|
||||||
def __init__(self, model, **args):
|
def __init__(self, model, **args):
|
||||||
@@ -59,12 +62,19 @@ class VllmEplbAdaptor:
|
|||||||
def init_expert_param_per_layer(self):
|
def init_expert_param_per_layer(self):
|
||||||
self.param_dict = dict()
|
self.param_dict = dict()
|
||||||
if self.model.quant_config is not None:
|
if self.model.quant_config is not None:
|
||||||
|
quant_type = self.model.model.layers[self.num_dense_layers].mlp.experts.quant_type
|
||||||
|
if quant_type == QuantType.W8A8:
|
||||||
self.expert_weight_names = [
|
self.expert_weight_names = [
|
||||||
"w13_weight_list",
|
"w13_weight_list",
|
||||||
"w2_weight_list",
|
"w2_weight_list",
|
||||||
"w13_weight_scale_fp32_list",
|
"w13_weight_scale_fp32_list",
|
||||||
"w2_weight_scale_list",
|
"w2_weight_scale_list",
|
||||||
]
|
]
|
||||||
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
|
self.expert_weight_names.append("fused_w1_scale_list")
|
||||||
|
self.expert_weight_names.append("fused_w2_scale_list")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"EPLB not support {quant_type}")
|
||||||
else:
|
else:
|
||||||
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
||||||
|
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class FusedExpertsResult:
|
|||||||
before_dispatch_evt: torch.npu.Event | None = None
|
before_dispatch_evt: torch.npu.Event | None = None
|
||||||
before_combine_evt: torch.npu.Event | None = None
|
before_combine_evt: torch.npu.Event | None = None
|
||||||
# For dynamic_eplb
|
# For dynamic_eplb
|
||||||
group_list_type: int | None = None
|
group_list_type: int = 1
|
||||||
expert_tokens: torch.Tensor | None = None
|
expert_tokens: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
@@ -355,7 +355,6 @@ class FusedMC2CommImpl(MoECommMethod):
|
|||||||
if log2phy is not None:
|
if log2phy is not None:
|
||||||
topk_ids = log2phy[topk_ids]
|
topk_ids = log2phy[topk_ids]
|
||||||
|
|
||||||
group_list_type = None
|
|
||||||
expert_tokens = None
|
expert_tokens = None
|
||||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
out = torch.empty_like(hidden_states)
|
out = torch.empty_like(hidden_states)
|
||||||
@@ -375,7 +374,6 @@ class FusedMC2CommImpl(MoECommMethod):
|
|||||||
expert_tokens = self.expert_token_nums
|
expert_tokens = self.expert_token_nums
|
||||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||||
assert expert_map is not None, "expert_map cannot be None."
|
assert expert_map is not None, "expert_map cannot be None."
|
||||||
group_list_type = 1
|
|
||||||
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
|
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
expert_ids=topk_ids,
|
expert_ids=topk_ids,
|
||||||
@@ -393,4 +391,4 @@ class FusedMC2CommImpl(MoECommMethod):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
raise ValueError(f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||||
return FusedExpertsResult(routed_out=out, group_list_type=group_list_type, expert_tokens=expert_tokens)
|
return FusedExpertsResult(routed_out=out, expert_tokens=expert_tokens)
|
||||||
|
|||||||
@@ -235,28 +235,28 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
|||||||
topk_weights = topk_weights.to(self.in_dtype)
|
topk_weights = topk_weights.to(self.in_dtype)
|
||||||
|
|
||||||
moe_comm_method = get_forward_context().moe_comm_method
|
moe_comm_method = get_forward_context().moe_comm_method
|
||||||
if self.dynamic_eplb:
|
|
||||||
w1 = layer.w13_weight_list
|
|
||||||
w1_scale = layer.w13_weight_scale_fp32_list
|
|
||||||
w2 = layer.w2_weight_list
|
|
||||||
w2_scale = layer.w2_weight_scale_list
|
|
||||||
else:
|
|
||||||
w1 = [layer.w13_weight]
|
|
||||||
w1_scale = [layer.w13_weight_scale_fp32]
|
|
||||||
w2 = [layer.w2_weight]
|
|
||||||
w2_scale = [layer.w2_weight_scale]
|
|
||||||
|
|
||||||
fused_scale_flag = (
|
fused_scale_flag = (
|
||||||
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
|
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
|
||||||
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
|
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
|
||||||
)
|
)
|
||||||
|
if self.dynamic_eplb:
|
||||||
|
w1 = layer.w13_weight_list
|
||||||
|
w1_scale = layer.fused_w1_scale_list if fused_scale_flag else layer.w13_weight_scale_fp32_list
|
||||||
|
w2 = layer.w2_weight_list
|
||||||
|
w2_scale = layer.fused_w2_scale_list if fused_scale_flag else layer.w2_weight_scale_list
|
||||||
|
else:
|
||||||
|
w1 = [layer.w13_weight]
|
||||||
|
w1_scale = [layer.fused_w1_scale] if fused_scale_flag else [layer.w13_weight_scale_fp32]
|
||||||
|
w2 = [layer.w2_weight]
|
||||||
|
w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
|
||||||
|
|
||||||
final_hidden_states = moe_comm_method.fused_experts(
|
final_hidden_states = moe_comm_method.fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
pertoken_scale=pertoken_scale,
|
pertoken_scale=pertoken_scale,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
|
w1_scale=w1_scale,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
|
w2_scale=w2_scale,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
use_int8_w8a8=True,
|
use_int8_w8a8=True,
|
||||||
@@ -282,6 +282,7 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
|||||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
|
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
|
||||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
|
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
|
||||||
|
|
||||||
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
|
layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
|
||||||
layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
|
layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
|
||||||
|
|
||||||
@@ -292,9 +293,21 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
|||||||
weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
|
weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
|
||||||
]
|
]
|
||||||
layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)]
|
layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)]
|
||||||
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
|
layer.fused_w1_scale_list = [
|
||||||
|
weight.clone()
|
||||||
|
for weight in layer.fused_w1_scale.view(len(layer.w13_weight_list), -1).data.unbind(dim=0)
|
||||||
|
]
|
||||||
|
layer.fused_w2_scale_list = [
|
||||||
|
weight.clone()
|
||||||
|
for weight in layer.fused_w2_scale.view(len(layer.w2_weight_list), -1).data.unbind(dim=0)
|
||||||
|
]
|
||||||
del layer.w13_weight
|
del layer.w13_weight
|
||||||
del layer.w2_weight
|
del layer.w2_weight
|
||||||
del layer.w13_weight_scale
|
del layer.w13_weight_scale
|
||||||
del layer.w13_weight_scale_fp32
|
del layer.w13_weight_scale_fp32
|
||||||
del layer.w2_weight_scale
|
del layer.w2_weight_scale
|
||||||
|
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||||
|
del layer.fused_w1_scale
|
||||||
|
del layer.fused_w2_scale
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user