init
This commit is contained in:
69
model_executor/layers/fused_moe/prepare_finalize.py
Normal file
69
model_executor/layers/fused_moe/prepare_finalize.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
_moe_unpermute_and_reduce)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
per_channel_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.per_channel_quant = per_channel_quant
|
||||
self.block_shape = block_shape
|
||||
self.quant_dtype = quant_dtype
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
return None
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, \
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
a1.mul_(topk_weights.to(a1.dtype))
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
|
||||
self.quant_dtype,
|
||||
self.per_channel_quant,
|
||||
self.block_shape)
|
||||
|
||||
return a1q, a1q_scale, None, None, None
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
_moe_unpermute_and_reduce(output, fused_expert_output, None,
|
||||
topk_weights, apply_router_weight_on_input)
|
||||
Reference in New Issue
Block a user