# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    _moe_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)


class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
    """Prepare/finalize for the non-expert-parallel (no-EP) case.

    Since all experts live on the local rank, there is no dispatch/combine
    communication: `prepare` only quantizes the activations and `finalize`
    only reduces the top-k expert outputs into the final output tensor.
    """

    def __init__(
        self,
        quant_dtype: Optional[torch.dtype] = None,
        per_channel_quant: bool = False,
        block_shape: Optional[list[int]] = None,
    ):
        super().__init__()
        self.per_channel_quant = per_channel_quant
        self.block_shape = block_shape
        self.quant_dtype = quant_dtype

    def max_num_tokens_per_rank(self) -> Optional[int]:
        # No EP, so there is no per-rank token limit.
        return None

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        # No kernel-imposed dtype requirement for the top-k indices.
        return None

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:

        if apply_router_weight_on_input:
            topk = topk_ids.size(1)
            # TODO: this only works for topK=1, will need to update for topK>1
            assert topk == 1, \
                "apply_router_weight_on_input is only implemented for topk=1"
            # Fold the router weights into the activations in place so that
            # finalize can skip the weighting step.
            a1.mul_(topk_weights.to(a1.dtype))

        # Quantize the activations (a passthrough when quant_dtype is None).
        a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
                                                   self.quant_dtype,
                                                   self.per_channel_quant,
                                                   self.block_shape)

        # No extra dispatch metadata is produced in the no-EP case.
        return a1q, a1q_scale, None, None, None

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
    ) -> None:
        # Reduce the top-k expert outputs into `output`, applying the router
        # weights here unless they were already applied to the input.
        _moe_unpermute_and_reduce(output, fused_expert_output, None,
                                  topk_weights, apply_router_weight_on_input)
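

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the vLLM source tree). It assumes
# the installed vLLM version exposes the modules imported above unchanged,
# that this class satisfies every abstract method of
# mk.FusedMoEPrepareAndFinalize, and that quant_dtype=None makes
# moe_kernel_quantize_input a passthrough (a1q is a1, a1q_scale is None).
# Shapes are arbitrary example values.
if __name__ == "__main__":
    num_tokens, hidden_dim, topk, num_experts = 4, 8, 1, 2

    pf = MoEPrepareAndFinalizeNoEP()

    a1 = torch.randn(num_tokens, hidden_dim)
    # topk must be 1 when apply_router_weight_on_input=True (see the assert
    # in `prepare`); the weights are folded into `a1` in place.
    topk_weights = torch.rand(num_tokens, topk)
    topk_ids = torch.randint(0, num_experts, (num_tokens, topk))

    a1q, a1q_scale, _, _, _ = pf.prepare(
        a1,
        a1_scale=None,
        a2_scale=None,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        num_experts=num_experts,
        expert_map=None,
        apply_router_weight_on_input=True,
    )
    print(a1q.shape, a1q_scale)  # -> torch.Size([4, 8]) None

    # A fused-experts kernel would now produce one row per (token, top-k
    # expert) pair, and `finalize` would reduce those rows into an output of
    # shape (num_tokens, hidden_dim). That step is omitted here because the
    # reduction relies on vLLM's custom device ops.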