# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey


class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
    """Base class for runtime dispatching of expert implementations."""

    def __init__(
        self,
        experts: mk.FusedMoEPermuteExpertsUnpermute,
        fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
    ):
        """Wrap a primary experts implementation together with its fallback.

        The base class is initialised from the primary implementation's
        ``moe_config`` and ``quant_config``.
        """
        moe_config = experts.moe_config
        quant_config = experts.quant_config
        super().__init__(moe_config=moe_config, quant_config=quant_config)
        self.experts = experts
        self.fallback_experts = fallback_experts

    @staticmethod
    def get_clses() -> tuple[
        type[mk.FusedMoEPermuteExpertsUnpermute],
        type[mk.FusedMoEPermuteExpertsUnpermute],
    ]:
        """
        Return the ``(experts, fallback experts)`` classes, in that order.

        Subclasses are expected to override this so the class-level
        ``_supports_*`` queries below can be answered uniformly. This base
        version always raises ``NotImplementedError``.
        """
        raise NotImplementedError(
            "Subclasses must return the cls for the experts and fallback experts."
        )

    @classmethod
    def activation_format(
        cls: type["FallbackExperts"],
    ) -> mk.FusedMoEActivationFormat:
        """Return the activation format of the primary experts class.

        Asserts that the fallback class reports the same format.
        """
        primary_cls, secondary_cls = cls.get_clses()
        assert primary_cls.activation_format() == secondary_cls.activation_format()
        return primary_cls.activation_format()

    @classmethod
    def _supports_current_device(cls) -> bool:
        """Report whether both implementation classes support this device."""
        primary_cls, secondary_cls = cls.get_clses()
        primary_ok = primary_cls._supports_current_device()
        if not primary_ok:
            # Short-circuit: the fallback class is not consulted.
            return primary_ok
        return secondary_cls._supports_current_device()

    @classmethod
    def _supports_no_act_and_mul(cls) -> bool:
        """Report whether both implementation classes support no-act-and-mul."""
        primary_cls, secondary_cls = cls.get_clses()
        primary_ok = primary_cls._supports_no_act_and_mul()
        if not primary_ok:
            # Short-circuit: the fallback class is not consulted.
            return primary_ok
        return secondary_cls._supports_no_act_and_mul()

    @classmethod
    def _supports_quant_scheme(
        cls,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Report whether both classes support the given quantization keys."""
        primary_cls, secondary_cls = cls.get_clses()
        primary_ok = primary_cls._supports_quant_scheme(weight_key, activation_key)
        if not primary_ok:
            # Short-circuit: the fallback class is not consulted.
            return primary_ok
        return secondary_cls._supports_quant_scheme(weight_key, activation_key)

    @classmethod
    def _supports_activation(cls, activation: MoEActivation) -> bool:
        """Report whether both classes support the given activation."""
        primary_cls, secondary_cls = cls.get_clses()
        primary_ok = primary_cls._supports_activation(activation)
        if not primary_ok:
            # Short-circuit: the fallback class is not consulted.
            return primary_ok
        return secondary_cls._supports_activation(activation)

    @classmethod
    def _supports_parallel_config(
        cls, moe_parallel_config: FusedMoEParallelConfig
    ) -> bool:
        """Report whether both classes support the given parallel config."""
        primary_cls, secondary_cls = cls.get_clses()
        primary_ok = primary_cls._supports_parallel_config(moe_parallel_config)
        if not primary_ok:
            # Short-circuit: the fallback class is not consulted.
            return primary_ok
        return secondary_cls._supports_parallel_config(moe_parallel_config)

    def supports_chunking(self) -> bool:
        """Report whether both wrapped instances support chunking.

        Asserts that the two instances give the same answer.
        """
        primary, secondary = self.experts, self.fallback_experts
        assert primary.supports_chunking() == secondary.supports_chunking()
        primary_ok = primary.supports_chunking()
        if not primary_ok:
            return primary_ok
        return secondary.supports_chunking()

    def supports_expert_map(self) -> bool:
        """Report whether both wrapped instances support an expert map.

        Asserts that the two instances give the same answer.
        """
        primary, secondary = self.experts, self.fallback_experts
        assert primary.supports_expert_map() == secondary.supports_expert_map()
        primary_ok = primary.supports_expert_map()
        if not primary_ok:
            return primary_ok
        return secondary.supports_expert_map()

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return the weight-and-reduce impl reported by the wrapped instances.

        The primary instance's value is preferred. When both instances report
        a non-``None`` value they are asserted to be equal; when the primary
        reports ``None`` the fallback's value is asserted to be non-``None``
        and returned instead.
        """
        primary_war = self.experts.finalize_weight_and_reduce_impl()
        fallback_war = self.fallback_experts.finalize_weight_and_reduce_impl()

        if primary_war is None:
            assert fallback_war is not None
            return fallback_war

        if fallback_war is not None:
            assert primary_war == fallback_war, (
                "Both implementations should agree on WeightAndReduce impls. "
                f"Got e_war: {primary_war}, and fbe_war: {fallback_war}"
            )
        return primary_war

    @abstractmethod
    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return three shape tuples; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def _select_experts_impl(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Pick the implementation that ``apply`` delegates to.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Forward the call to the implementation chosen for these inputs.

        The implementation is selected by ``_select_experts_impl`` from
        ``hidden_states``, ``w1`` and ``w2``; every argument is then passed
        through positionally and unchanged. Nothing is returned.
        """
        # Arguments stay positional: the delegate's parameter names are not
        # visible from here.
        self._select_experts_impl(hidden_states, w1, w2).apply(
            output,
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            activation,
            global_num_experts,
            expert_map,
            a1q_scale,
            a2_scale,
            workspace13,
            workspace2,
            expert_tokens_meta,
            apply_router_weight_on_input,
        )