# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate, ) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard def max_num_tokens_per_rank(self) -> int | None: return None def topk_indices_dtype(self) -> torch.dtype | None: return None def num_dispatchers(self) -> int: return 1 def output_is_reduced(self) -> bool: return False def prepare( self, a1: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1" ) # Note: do not use inplace for shared experts overlap a1 = a1 * topk_weights.to(a1.dtype) a1q, a1q_scale = moe_kernel_quantize_input( a1, quant_config.a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, ) return a1q, a1q_scale, None, None, None def finalize( self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): weight_and_reduce_impl = TopKWeightAndReduceContiguous() weight_and_reduce_impl.apply( output=output, fused_expert_output=fused_expert_output, topk_weights=topk_weights, topk_ids=topk_ids, apply_router_weight_on_input=apply_router_weight_on_input, )