[EPLB][bugfix] Bugfix for fused mc2 (#6794)
### What this PR does / why we need it?
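This pull request fixes a bug in the fused MC2 path of the W8A8 dynamic
fused-MoE quantization method when dynamic EPLB (Expert Parallelism Load
Balancer) is enabled. When the MoE communication type is `FUSED_MC2` and
`VLLM_ASCEND_ENABLE_FUSED_MC2 == 1`, the weight scales passed to
`fused_experts` are now taken from the list-valued `fused_w1_scale_list` /
`fused_w2_scale_list` under dynamic EPLB (and from `fused_w1_scale` /
`fused_w2_scale` otherwise). The fused int64 scales are also only built
when `VLLM_ASCEND_ENABLE_FUSED_MC2 == 1`.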
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main: 83b47f67b1
Signed-off-by: Spicy-Stick <873805887@qq.com>
Signed-off-by: root <root@localhost.localdomain>
This commit is contained in:
@@ -235,28 +235,28 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
topk_weights = topk_weights.to(self.in_dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
if self.dynamic_eplb:
|
||||
w1 = layer.w13_weight_list
|
||||
w1_scale = layer.w13_weight_scale_fp32_list
|
||||
w2 = layer.w2_weight_list
|
||||
w2_scale = layer.w2_weight_scale_list
|
||||
else:
|
||||
w1 = [layer.w13_weight]
|
||||
w1_scale = [layer.w13_weight_scale_fp32]
|
||||
w2 = [layer.w2_weight]
|
||||
w2_scale = [layer.w2_weight_scale]
|
||||
|
||||
fused_scale_flag = (
|
||||
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
|
||||
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
|
||||
)
|
||||
if self.dynamic_eplb:
|
||||
w1 = layer.w13_weight_list
|
||||
w1_scale = layer.fused_w1_scale_list if fused_scale_flag else layer.w13_weight_scale_fp32_list
|
||||
w2 = layer.w2_weight_list
|
||||
w2_scale = layer.fused_w2_scale_list if fused_scale_flag else layer.w2_weight_scale_list
|
||||
else:
|
||||
w1 = [layer.w13_weight]
|
||||
w1_scale = [layer.fused_w1_scale] if fused_scale_flag else [layer.w13_weight_scale_fp32]
|
||||
w2 = [layer.w2_weight]
|
||||
w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
|
||||
|
||||
final_hidden_states = moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
pertoken_scale=pertoken_scale,
|
||||
w1=w1,
|
||||
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
|
||||
w2_scale=w2_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int8_w8a8=True,
|
||||
@@ -282,8 +282,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
|
||||
|
||||
layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
|
||||
layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
|
||||
layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
|
||||
|
||||
if self.dynamic_eplb:
|
||||
layer.w13_weight_list = [weight.clone() for weight in layer.w13_weight.data.unbind(dim=0)]
|
||||
@@ -292,9 +293,21 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
|
||||
]
|
||||
layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)]
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
layer.fused_w1_scale_list = [
|
||||
weight.clone()
|
||||
for weight in layer.fused_w1_scale.view(len(layer.w13_weight_list), -1).data.unbind(dim=0)
|
||||
]
|
||||
layer.fused_w2_scale_list = [
|
||||
weight.clone()
|
||||
for weight in layer.fused_w2_scale.view(len(layer.w2_weight_list), -1).data.unbind(dim=0)
|
||||
]
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
del layer.w13_weight_scale
|
||||
del layer.w13_weight_scale_fp32
|
||||
del layer.w2_weight_scale
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
del layer.fused_w1_scale
|
||||
del layer.fused_w2_scale
|
||||
torch.npu.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user