diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 2b7d57c..c6e863f 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -150,8 +150,9 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
                            log2phy: torch.Tensor = None,
                            global_redundant_expert_num: int = 0,
                            **kwargs) -> torch.Tensor:
-
-    topk_ids = log2phy[topk_ids]
+    # log2phy is an optional remap table; compare against None explicitly --
+    # `if log2phy:` raises for a multi-element tensor (ambiguous truthiness).
+    if log2phy is not None:
+        topk_ids = log2phy[topk_ids]
     global_bs = 0
     moe_expert_num = len(expert_map) + global_redundant_expert_num
     # hidden_states = hidden_states.bfloat16()
@@ -278,7 +278,9 @@ def fused_experts_with_all2all(
     log2phy: torch.Tensor = None,
     global_redundant_expert_num: int = 0,
 ):
-    topk_ids = log2phy[topk_ids]
+    # Same optional remap as in fused_experts_with_mc2: a tensor must be
+    # tested with `is not None`, never by truthiness.
+    if log2phy is not None:
+        topk_ids = log2phy[topk_ids]
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])