Fix W8A8 fused MoE bug (#1529)

### What this PR does / why we need it?
1. Drop unused code from the W8A8 fused MoE path.
2. Add an int8 KV cache check.
3. Add more unit tests.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with the newly added tests.

---------

Signed-off-by: zhuyilin <809721801@qq.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
Commit: 6b80c5acba (parent 7fc1a98489)
Author: Zhu Yi Lin
Committed: 2025-07-02 16:40:51 +08:00 (via GitHub)
8 changed files with 1623 additions and 53 deletions


@@ -251,7 +251,8 @@ class VLLMAscendQuantizer:
         # Attention
         if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
             quant_type = quant_description['fa_quant_type']
-        if '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
+        # Use KVCache int8
+        elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
             quant_type = quant_description['kv_quant_type']
         # Linear
         else:
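
In the hunk above, the second attention check becomes an `elif`, so the trailing `else` (the linear branch) now pairs with the whole attention block rather than with the KV-cache check alone. Below is a minimal sketch of the resulting selection logic, assuming a plain dict config; the helper name `resolve_quant_type` and the example value `"C8"` are illustrative, not part of the project's API.

```python
from typing import Optional


def resolve_quant_type(prefix: str, quant_description: dict) -> Optional[str]:
    quant_type = None
    # Attention: full attention quantization takes priority if configured.
    if '.attn' in prefix and 'fa_quant_type' in quant_description:
        quant_type = quant_description['fa_quant_type']
    # Use KVCache int8
    elif '.attn' in prefix and 'kv_quant_type' in quant_description:
        quant_type = quant_description['kv_quant_type']
    return quant_type


# An attention layer with only `kv_quant_type` set resolves to the KV-cache
# int8 path instead of falling through to the linear branch.
print(resolve_quant_type("model.layers.0.self_attn.attn",
                         {"kv_quant_type": "C8"}))  # -> C8 (value assumed)
```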


@@ -15,7 +15,6 @@
 # limitations under the License.
 #
-import os
 from typing import Any, Callable, Dict, Optional

 import torch
@@ -219,53 +218,34 @@ class AscendW8A8FusedMoEMethod:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"

-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,
-                bias=e_score_correction_bias,
-                k_group=topk_group,
-                group_count=num_expert_group,
-                group_select_mode=1,
-                renorm=0,
-                norm_type=1,
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-                global_num_experts=global_num_experts,
-            )
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            global_num_experts=global_num_experts,
+        )

-        if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
-            raise NotImplementedError("W8A8FusedMoe are not "
-                                      "implemented for VLLM_ENABLE_MC2")
-        else:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w1_scale=layer.w13_weight_scale,
-                                 w1_input_scale=layer.w13_input_scale,
-                                 w1_input_offset=layer.w13_input_offset,
-                                 w2=layer.w2_weight,
-                                 w2_scale=layer.w2_weight_scale,
-                                 w2_input_scale=layer.w2_input_scale,
-                                 w2_input_offset=layer.w2_input_offset,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 global_num_experts=global_num_experts,
-                                 expert_map=expert_map)
+        return fused_experts(hidden_states=x,
+                             w1=layer.w13_weight,
+                             w1_scale=layer.w13_weight_scale,
+                             w1_input_scale=layer.w13_input_scale,
+                             w1_input_offset=layer.w13_input_offset,
+                             w2=layer.w2_weight,
+                             w2_scale=layer.w2_weight_scale,
+                             w2_input_scale=layer.w2_input_scale,
+                             w2_input_offset=layer.w2_input_offset,
+                             topk_weights=topk_weights,
+                             topk_ids=topk_ids,
+                             top_k=top_k,
+                             global_num_experts=global_num_experts,
+                             expert_map=expert_map)

     def process_weights_after_loading(self, layer):
         # torch.npu.config.allow_internal_format = True
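
The hunk above drops the `global_num_experts == 256` special case and the `VLLM_ENABLE_MC2` guard, leaving a single `select_experts` to `fused_experts` path. As a reference for the shapes involved, here is a generic softmax-plus-top-k routing sketch showing the `[num_tokens, top_k]` weight and id tensors that `fused_experts` consumes; it is plain PyTorch under assumed shapes, not the vllm-ascend implementation of `select_experts`.

```python
import torch


def naive_select_experts(router_logits: torch.Tensor, top_k: int,
                         renormalize: bool = True):
    """Generic softmax + top-k routing; returns per-token expert weights/ids."""
    scores = torch.softmax(router_logits, dim=-1)   # [num_tokens, num_experts]
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


logits = torch.randn(4, 256)            # 4 tokens, 256 experts (shapes assumed)
weights, ids = naive_select_experts(logits, top_k=8)
print(weights.shape, ids.shape)         # torch.Size([4, 8]) torch.Size([4, 8])
```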
@@ -299,8 +279,10 @@ class AscendW8A8FusedMoEMethod:
             torch.int8)
         # NZ
-        # layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, 29).contiguous()
-        # layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, 29).contiguous()
+        layer.w13_weight.data = torch_npu.npu_format_cast(
+            layer.w13_weight.data, 29).contiguous()
+        layer.w2_weight.data = torch_npu.npu_format_cast(
+            layer.w2_weight.data, 29).contiguous()


 class AscendC8KVCacheMethod:
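
The last hunk enables the NZ weight layout by uncommenting the `npu_format_cast` calls. A short sketch of that cast is below; it assumes an Ascend NPU with `torch_npu` installed, and the tensor shape is made up for illustration (29 is the ACL format id used in the diff).

```python
import torch
import torch_npu

# Made-up int8 expert weight; in the diff the real tensors are
# layer.w13_weight.data and layer.w2_weight.data.
weight = torch.randint(-128, 128, (64, 1024, 512), dtype=torch.int8).npu()
# Cast from the default ND layout to NZ (ACL format id 29) and keep it
# contiguous, mirroring the now-enabled lines above.
weight = torch_npu.npu_format_cast(weight, 29).contiguous()
print(weight.shape, weight.dtype)
```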