[gpt-oss] Add gpt-oss mxfp4 support
This commit is contained in:
@@ -4,8 +4,12 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize)
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
_config: Optional[dict[str, Any]] = None
|
||||
@@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]:
|
||||
|
||||
__all__ = [
|
||||
"FusedMoE",
|
||||
"FusedMoEConfig",
|
||||
"FusedMoEMethodBase",
|
||||
"FusedMoeWeightScaleSupported",
|
||||
"FusedMoEPermuteExpertsUnpermute",
|
||||
"FusedMoEActivationFormat",
|
||||
"FusedMoEPrepareAndFinalize",
|
||||
"override_config",
|
||||
"get_config",
|
||||
]
|
||||
@@ -36,11 +44,21 @@ if HAS_TRITON:
|
||||
# import to register the custom ops
|
||||
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
||||
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp4, cutlass_moe_fp8)
|
||||
CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
TritonExperts, fused_experts, fused_moe, fused_topk,
|
||||
get_config_file_name, grouped_topk)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
|
||||
__all__ += [
|
||||
"fused_moe",
|
||||
@@ -50,5 +68,11 @@ if HAS_TRITON:
|
||||
"grouped_topk",
|
||||
"cutlass_moe_fp8",
|
||||
"cutlass_moe_fp4",
|
||||
"CutlassExpertsFp8",
|
||||
"TritonExperts",
|
||||
"BatchedTritonExperts",
|
||||
"DeepGemmExperts",
|
||||
"BatchedDeepGemmExperts",
|
||||
"TritonOrDeepGemmExperts",
|
||||
"BatchedTritonOrDeepGemmExperts",
|
||||
]
|
||||
|
||||
490
vllm/model_executor/layers/fused_moe/config.py
Normal file
490
vllm/model_executor/layers/fused_moe/config.py
Normal file
@@ -0,0 +1,490 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.utils import cdiv
|
||||
# from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_quant_config_quantization_args(
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prop_name: str,
|
||||
) -> Optional[QuantizationArgs]:
|
||||
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
|
||||
and "Linear" in quant_config.target_scheme_map and
|
||||
"input_activations" in quant_config.target_scheme_map["Linear"]):
|
||||
return quant_config.target_scheme_map["Linear"].get(prop_name)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_quant_config_input_quant(
|
||||
quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[QuantizationArgs]:
|
||||
return _get_quant_config_quantization_args(quant_config,
|
||||
"input_activations")
|
||||
|
||||
|
||||
def get_quant_config_weight_quant(
|
||||
quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[QuantizationArgs]:
|
||||
return _get_quant_config_quantization_args(quant_config, "weights")
|
||||
|
||||
|
||||
# TODO (bnell): use scalar_type instead of bools?
|
||||
def get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
use_mxfp4_w4a4: bool,
|
||||
) -> Union[None, torch.dtype, str]:
|
||||
if use_fp8_w8a8:
|
||||
return torch.float8_e4m3fn
|
||||
elif use_int8_w8a8:
|
||||
return torch.int8
|
||||
elif use_mxfp4_w4a4:
|
||||
return "mxfp4"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEQuantConfig:
|
||||
# The post quantization activation type.
|
||||
quant_dtype: Optional[torch.dtype] = None
|
||||
per_act_token_quant: bool = False
|
||||
per_out_ch_quant: bool = False
|
||||
block_shape: Optional[list[int]] = None
|
||||
|
||||
# TODO: add col major flag?
|
||||
# add detailed quant info for input, intermediates, weights, etc?
|
||||
|
||||
def __post_init__(self):
|
||||
assert (not self.per_act_token_quant
|
||||
or self.block_shape is None), "illegal quantization"
|
||||
|
||||
@property
|
||||
def is_quantized(self) -> bool:
|
||||
return self.quant_dtype is not None
|
||||
|
||||
@property
|
||||
def is_per_act_token(self) -> bool:
|
||||
return self.per_act_token_quant
|
||||
|
||||
@property
|
||||
def is_block_quantized(self) -> bool:
|
||||
return self.block_shape is not None
|
||||
|
||||
@property
|
||||
def is_per_tensor(self) -> bool:
|
||||
return not self.per_act_token_quant and self.block_shape is None
|
||||
|
||||
def scale_shape(
|
||||
self,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int]]:
|
||||
if self.is_quantized:
|
||||
if self.is_block_quantized:
|
||||
assert self.block_shape is not None
|
||||
_, block_k = self.block_shape
|
||||
k_tiles = cdiv(hidden_dim, block_k)
|
||||
return (max_tokens, k_tiles)
|
||||
elif self.is_per_act_token:
|
||||
return (max_tokens, 1)
|
||||
else:
|
||||
return (1, 1)
|
||||
else:
|
||||
return None
|
||||
|
||||
def batched_scale_shape(
|
||||
self,
|
||||
num_experts: int,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int, int]]:
|
||||
if self.is_quantized:
|
||||
scale_shape = self.scale_shape(max_tokens, hidden_dim)
|
||||
assert scale_shape is not None
|
||||
return (num_experts, *scale_shape)
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
assert sum([
|
||||
int(flag) for flag in [
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
]
|
||||
]) <= 1, "Quantization flags are mutually exclusive."
|
||||
|
||||
quant_dtype = get_config_quant_dtype(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
)
|
||||
return FusedMoEQuantConfig(
|
||||
quant_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEParallelConfig:
|
||||
tp_size: int
|
||||
dp_size: int
|
||||
ep_size: int
|
||||
tp_rank: int
|
||||
dp_rank: int
|
||||
ep_rank: int
|
||||
|
||||
use_ep: bool # whether to use EP or not
|
||||
|
||||
@property
|
||||
def use_all2all_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
|
||||
|
||||
@property
|
||||
def use_flashinfer_cutlass_kernels(self):
|
||||
# return (envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
# and has_flashinfer_cutlass_fused_moe()
|
||||
# and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def make(tp_size_: int, dp_size_: int,
|
||||
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
|
||||
"""
|
||||
Determine MoE parallel configuration. Based on the input `tp_size_`,
|
||||
`dp_size_` and vllm's parallel config, determine what
|
||||
level's of parallelism to use in the fused moe layer.
|
||||
|
||||
Args:
|
||||
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
|
||||
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
|
||||
vllm_parallel_config (ParallelConfig): vLLM's parallel config
|
||||
object which contains the `enable_expert_parallel` flag.
|
||||
|
||||
Examples:
|
||||
When there is no parallelism requested,
|
||||
i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
|
||||
unaltered and the ranks set to 0.
|
||||
|
||||
Expert Parallelism is considered only when either `dp_size_` or
|
||||
`tp_size_` is non trivial.
|
||||
|
||||
When TP = 2, DP = 1 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
||||
legend : {size, rank}
|
||||
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
||||
- Comment : Tensors are sharded across 2 devices.
|
||||
|
||||
When TP = 1, DP = 2 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 2 decvices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
|
||||
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
|
||||
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 4 devices.
|
||||
|
||||
When, TP = 2, DP = 1 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
||||
- Comment: The experts are split between the 2 devices.
|
||||
|
||||
When, TP = 1, DP = 2 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 2 devices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
|
||||
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
|
||||
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 4 devices.
|
||||
"""
|
||||
|
||||
def flatten_tp_across_dp(dp_rank: int):
|
||||
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
|
||||
# There are actually dp_size_ * tp_size_ devices. Update tp_size
|
||||
# and tp_rank so we shard across all devices.
|
||||
tp_size = dp_size_ * tp_size_
|
||||
tp_rank = dp_rank * tp_size_ + tp_rank
|
||||
return tp_size, tp_rank
|
||||
|
||||
use_ep = (dp_size_ * tp_size_ > 1
|
||||
and vllm_parallel_config.enable_expert_parallel)
|
||||
|
||||
dp_size = dp_size_
|
||||
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
||||
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
|
||||
|
||||
if not use_ep:
|
||||
return FusedMoEParallelConfig(tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=1,
|
||||
ep_rank=0,
|
||||
use_ep=False)
|
||||
# DP + EP / TP + EP / DP + TP + EP
|
||||
assert use_ep
|
||||
# In EP, each device owns a set of experts fully. There is no tensor
|
||||
# parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
|
||||
ep_size = tp_size
|
||||
ep_rank = tp_rank
|
||||
return FusedMoEParallelConfig(tp_size=1,
|
||||
tp_rank=0,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
use_ep=True)
|
||||
|
||||
|
||||
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
||||
@dataclass
|
||||
class FusedMoEConfig:
|
||||
num_experts: int
|
||||
experts_per_token: int
|
||||
hidden_dim: int
|
||||
|
||||
num_local_experts: int
|
||||
moe_parallel_config: FusedMoEParallelConfig
|
||||
|
||||
# The activation type.
|
||||
in_dtype: torch.dtype
|
||||
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||
|
||||
has_bias: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dp_size > 1:
|
||||
logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d",
|
||||
self.max_num_tokens)
|
||||
|
||||
assert self.max_num_tokens > 0
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.quant_dtype
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.block_shape
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def per_act_token_quant(self) -> bool:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.per_act_token_quant
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def per_out_ch_quant(self) -> bool:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.per_out_ch_quant
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
|
||||
@property
|
||||
def dp_size(self):
|
||||
return self.moe_parallel_config.dp_size
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return self.moe_parallel_config.ep_size
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return self.moe_parallel_config.tp_rank
|
||||
|
||||
@property
|
||||
def dp_rank(self):
|
||||
return self.moe_parallel_config.dp_rank
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return self.moe_parallel_config.ep_rank
|
||||
|
||||
@property
|
||||
def use_ep(self):
|
||||
return self.moe_parallel_config.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.moe_parallel_config.use_pplx_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ht_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ll_kernels
|
||||
|
||||
@property
|
||||
def use_flashinfer_cutlass_kernels(self):
|
||||
return self.moe_parallel_config.use_flashinfer_cutlass_kernels
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
num_experts: int,
|
||||
experts_per_token: int,
|
||||
hidden_dim: int,
|
||||
num_local_experts: int,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
in_dtype: torch.dtype,
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
quant_config: Optional[Union[FusedMoEQuantConfig,
|
||||
QuantizationConfig]] = None,
|
||||
has_bias: bool = False,
|
||||
) -> "FusedMoEConfig":
|
||||
|
||||
_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
|
||||
if quant_config is not None and isinstance(quant_config,
|
||||
QuantizationConfig):
|
||||
if hasattr(quant_config, 'weight_block_size'):
|
||||
block_shape = quant_config.weight_block_size
|
||||
else:
|
||||
block_shape = None
|
||||
per_act_token_quant = False
|
||||
per_out_ch_quant = False
|
||||
quant_dtype: Optional[torch.dtype] = None
|
||||
|
||||
input_quant = get_quant_config_input_quant(quant_config)
|
||||
weight_quant = get_quant_config_weight_quant(quant_config)
|
||||
|
||||
if input_quant is not None:
|
||||
per_act_token_quant = (input_quant.strategy
|
||||
== QuantizationStrategy.TOKEN
|
||||
if input_quant is not None else False)
|
||||
|
||||
if input_quant.num_bits == 8:
|
||||
if input_quant.type == QuantizationType.FLOAT:
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
elif input_quant.type == QuantizationType.INT:
|
||||
quant_dtype = torch.int8
|
||||
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
if quant_dtype is None and isinstance(quant_config, Fp8Config):
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4Config)
|
||||
if quant_dtype is None and isinstance(quant_config,
|
||||
ModelOptNvFp4Config):
|
||||
quant_dtype = torch.uint8
|
||||
|
||||
if weight_quant is not None:
|
||||
per_out_ch_quant = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||
|
||||
if quant_dtype is not None:
|
||||
_quant_config = FusedMoEQuantConfig(
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
else:
|
||||
_quant_config = FusedMoEQuantConfig()
|
||||
if moe_parallel_config.dp_size > 1:
|
||||
logger.warning_once("MoE DP setup unable to determine "
|
||||
"quantization scheme or unsupported "
|
||||
"quantization type. This model will "
|
||||
"not run with DP enabled.")
|
||||
else:
|
||||
_quant_config = quant_config
|
||||
|
||||
return FusedMoEConfig(
|
||||
num_experts=num_experts,
|
||||
experts_per_token=experts_per_token,
|
||||
hidden_dim=hidden_dim,
|
||||
num_local_experts=num_local_experts,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=in_dtype,
|
||||
quant_config=_quant_config,
|
||||
max_num_tokens=max_num_tokens,
|
||||
has_bias=has_bias,
|
||||
)
|
||||
@@ -1503,8 +1503,8 @@ def fused_experts_impl(
|
||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||
A=curr_hidden_states,
|
||||
A_scale=a1_scale,
|
||||
qtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
quant_dtype=qtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(qcurr_hidden_states,
|
||||
@@ -1562,8 +1562,8 @@ def fused_experts_impl(
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
A_scale=a2_scale,
|
||||
qtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
quant_dtype=qtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
|
||||
from vllm.utils import has_triton_kernels
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if has_triton_kernels():
|
||||
try:
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
|
||||
matmul_ogs)
|
||||
from triton_kernels.routing import routing
|
||||
except ModuleNotFoundError:
|
||||
logger.error(
|
||||
"Failed to import Triton kernels. Please make sure your triton "
|
||||
"version is compatible.")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from triton_kernels.matmul_ogs import PrecisionConfig
|
||||
|
||||
|
||||
def triton_kernel_moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
w1, # Tensor or triton_kernels.Tensor
|
||||
w2, # Tensor or triton_kernels.Tensor
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
w1_precision: Optional["PrecisionConfig"] = None,
|
||||
w2_precision: Optional["PrecisionConfig"] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
routing_data, gather_idx, scatter_idx = routing(gating_output,
|
||||
topk,
|
||||
sm_first=not renormalize)
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
None,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
routing_data,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
per_channel_quant=per_channel_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_precision=w1_precision,
|
||||
w2_precision=w2_precision,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
# This is a triton implementation of the fused_experts function
|
||||
def triton_kernel_fused_experts(
|
||||
output_tensor: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1, # Tensor or triton_kernels.Tensor
|
||||
w2, # Tensor or triton_kernels.Tensor
|
||||
routing_data, # RoutingData
|
||||
gather_indx, # GatherIndx
|
||||
scatter_indx, # ScatterIndx
|
||||
activation: str = "silu",
|
||||
swiglu_alpha: float = 1.702,
|
||||
swiglu_limit: float = 7.0,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
w1_precision: Optional["PrecisionConfig"] = None,
|
||||
w2_precision: Optional["PrecisionConfig"] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
assert w1_bias is None or w1_bias.dtype == torch.float32
|
||||
assert w2_bias is None or w2_bias.dtype == torch.float32
|
||||
|
||||
# Shape check, only check non-mxfp4
|
||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||
assert w2.shape[-1] == w1.shape[1]
|
||||
|
||||
E, _, N = w1.shape
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
act = FusedActivation(
|
||||
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
|
||||
(swiglu_alpha, swiglu_limit), 2)
|
||||
gammas = routing_data.gate_scal if routing_data else None
|
||||
|
||||
intermediate_cache1 = matmul_ogs(
|
||||
hidden_states,
|
||||
w1,
|
||||
w1_bias,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
precision_config=w1_precision,
|
||||
gammas=gammas if apply_router_weight_on_input else None,
|
||||
fused_activation=act)
|
||||
|
||||
intermediate_cache3 = matmul_ogs(
|
||||
intermediate_cache1,
|
||||
w2,
|
||||
w2_bias,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
precision_config=w2_precision,
|
||||
gammas=None if apply_router_weight_on_input else gammas,
|
||||
y=output_tensor,
|
||||
)
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
w1_precision: "PrecisionConfig",
|
||||
w2_precision: "PrecisionConfig",
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
self.w1_precision = w1_precision
|
||||
self.w2_precision = w2_precision
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
|
||||
def workspace_shapes(
|
||||
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
|
||||
topk: int, global_num_experts: int, local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
# workspace are allocated inside the kernel
|
||||
assert a.dim() == 2
|
||||
num_dp = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
max_num_tokens = self.max_num_tokens
|
||||
workspace2 = (0, 0, 0)
|
||||
output = (num_experts, max_num_tokens * num_dp, N)
|
||||
return (output, workspace2, output, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]],
|
||||
):
|
||||
w1_bias, w2_bias = (extract_required_args(extra_expert_args,
|
||||
["w1_bias", "w2_bias"]))
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
output,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=False,
|
||||
use_fp8_w8a8=False,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_precision=self.w1_precision,
|
||||
w2_precision=self.w2_precision,
|
||||
a1_scale=a1q_scale,
|
||||
a2_scale=a2_scale)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from math import prod
|
||||
from typing import Any, Optional, final
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
|
||||
_resize_cache, count_expert_num_tokens)
|
||||
from vllm.utils import cdiv
|
||||
|
||||
#
|
||||
# This file defines a set of base classes used to make MoE kernels more modular.
|
||||
# The goal is to be able to utilize different communication mechanisms with
|
||||
@@ -14,7 +23,7 @@ import torch
|
||||
#
|
||||
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
|
||||
#
|
||||
# Each component will be independent of the others except for
|
||||
# Each component will be independent of (but may inform) the others except for
|
||||
# [Quantize-Dispatch] and `[Combine] (see below). The components can then be
|
||||
# mixed and matched with so that DP+EP can be supported easily for multiple
|
||||
# MoE kernel implementations.
|
||||
@@ -23,13 +32,19 @@ import torch
|
||||
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
|
||||
# inputs (e.g. quantization, distribution) and finalization of Moe outputs.
|
||||
# The prepare method must take care of any needed quantization and the
|
||||
# finalize method must apply weights and do the final reduction of the output.
|
||||
# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
|
||||
# may apply weights and/or do the final reduction of the output.
|
||||
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
|
||||
# MoE operation. One important feature to note is that this class does not
|
||||
# apply topk weights or reduce the final output.
|
||||
# MoE operation, i.e matmul + act_mul + optionally quant + matmul.
|
||||
# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
|
||||
# the weight application and/or reduction. The class communicates this
|
||||
# to [Finalize] via a TopKWeightAndReduce object.
|
||||
# * FusedMoEModularKernel - an interface class that combines a
|
||||
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
|
||||
# provide the standard fused MoE kernel interface.
|
||||
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
|
||||
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
|
||||
# on to [Finalize].
|
||||
#
|
||||
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
|
||||
# class `FusedMoEPrepareAndFinalize` since they could use collective
|
||||
@@ -77,6 +92,56 @@ def _moe_problem_size(
|
||||
return E, M, N, K, topk
|
||||
|
||||
|
||||
class FusedMoEActivationFormat(Enum):
|
||||
"""
|
||||
The standard activation format (num_tokens, hidden dim).
|
||||
"""
|
||||
Standard = "standard",
|
||||
"""
|
||||
The batched experts format (num experts, max tokens per expert, hidden dim)
|
||||
"""
|
||||
BatchedExperts = "batched_experts",
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpertTokensMetadata:
|
||||
"""
|
||||
Metadata regarding expert-token routing.
|
||||
"""
|
||||
expert_num_tokens: torch.Tensor
|
||||
expert_num_tokens_cpu: Optional[torch.Tensor]
|
||||
|
||||
@staticmethod
|
||||
def make_from_list(expert_num_tokens_list: list[int],
|
||||
device: str) -> "ExpertTokensMetadata":
|
||||
expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=expert_num_tokens_cpu.to(device,
|
||||
non_blocking=True),
|
||||
expert_num_tokens_cpu=expert_num_tokens_cpu)
|
||||
|
||||
|
||||
class TopKWeightAndReduce(ABC):
|
||||
"""
|
||||
An abstract base class for weight application and reduction implementations.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, output: Optional[torch.Tensor],
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
"""
|
||||
Apply topk_weights to the fused_experts_outputs and/or reduce.
|
||||
If an output tensor is not passed, it will be created in the
|
||||
function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
@@ -85,17 +150,15 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, num_experts: int,
|
||||
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
extra_prepare_args: Optional[dict[str, Any]]
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed
|
||||
for this kernel.
|
||||
@@ -114,22 +177,20 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
Returns a tuple of:
|
||||
- quantized + dispatched a.
|
||||
- quantized + dispatched a1_scales.
|
||||
- Optional tensor as big as number of local experts that contains the
|
||||
number of tokens assigned to each local expert.
|
||||
- Optional ExpertTokensMetadata containing gpu/cpu tensors
|
||||
as big as the number of local experts with the information about the
|
||||
number of tokens assigned to each local expert.
|
||||
- Optional dispatched expert topk IDs
|
||||
- Optional dispatched expert topk weight
|
||||
- Optional dispatched expert topk weight
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: TopKWeightAndReduce,
|
||||
extra_finalize_args: Optional[dict[str, Any]]) -> None:
|
||||
"""
|
||||
Perform any combine plus apply weights and perform a reduction on the
|
||||
fused experts output.
|
||||
@@ -140,6 +201,17 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- topk_ids: The topk_ids.
|
||||
- apply_router_weight_on_input: When False, apply the weights to
|
||||
fused_expert_output.
|
||||
- weight_and_reduce_impl: An optional TopKWeightAndReduce
|
||||
implementation.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def activation_format(self) -> FusedMoEActivationFormat:
|
||||
"""
|
||||
A property indicating the output format of the activations for the
|
||||
'prepare' method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -159,11 +231,15 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
Some PrepareFinalize All2All implementations are batched. Meaning,
|
||||
they can processes only as set of tokens at a time. This
|
||||
function returns the batch size i.e the maximum number of tokens
|
||||
the implementation can process at a time.
|
||||
the implementation can process at a time.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def num_dispatchers(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
@@ -171,6 +247,57 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
above.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
):
|
||||
if quant_config is not None:
|
||||
self.quant_config = quant_config
|
||||
else:
|
||||
self.quant_config = FusedMoEQuantConfig()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def activation_formats(
|
||||
self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
|
||||
"""
|
||||
A property which is a tuple of the input and output activation formats
|
||||
for the 'apply' method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||
return self.quant_config.quant_dtype
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
return self.quant_config.block_shape
|
||||
|
||||
@property
|
||||
def per_act_token_quant(self) -> bool:
|
||||
return self.quant_config.per_act_token_quant
|
||||
|
||||
@property
|
||||
def per_out_ch_quant(self) -> bool:
|
||||
return self.quant_config.per_out_ch_quant
|
||||
|
||||
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
||||
@abstractmethod
|
||||
def supports_chunking(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports activation
|
||||
chunking.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def workspace_shapes(
|
||||
self,
|
||||
@@ -180,20 +307,25 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
"""
|
||||
Compute the number of elements for the temporary outputs of the two
|
||||
gemms and activation in the fused expert function. Since the
|
||||
gemms are independent, the workspace for the first gemm can be shared
|
||||
with the workspace for the last gemm.
|
||||
Compute the shapes for the temporary and final outputs of the two gemms
|
||||
and activation in the fused expert function. Since the gemms are
|
||||
independent, the workspace for the first gemm can be shared with the
|
||||
workspace for the last gemm.
|
||||
|
||||
Returns a tuple of:
|
||||
- Number of workspace13 elements: must be large enough to hold the
|
||||
- workspace13 shape tuple: must be large enough to hold the
|
||||
result of either expert gemm.
|
||||
- Number of workspace2 elements: must be large enough to hold the
|
||||
- workspace2 shape tuple: must be large enough to hold the
|
||||
result of the activation function.
|
||||
- output shape tuple: must be exact size of the final gemm output.
|
||||
- Workspace type: The dtype to use for the workspace tensors.
|
||||
- Note: in order for activation chunking to work, the first dimension
|
||||
of each tuple must be the number of tokens.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -207,12 +339,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
def enable_chunking(self):
|
||||
return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
|
||||
self.supports_chunking()
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
@@ -225,17 +366,22 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]],
|
||||
):
|
||||
"""
|
||||
This function computes the intermediate result of a Mixture of Experts
|
||||
(MoE) layer using two sets of weights, w1 and w2.
|
||||
|
||||
Parameters:
|
||||
- output: (torch.Tensor): The unweighted, unreduced output tensor.
|
||||
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
|
||||
layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_weights: A map of row to expert weights. Some implementations
|
||||
choose to do weight application.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
@@ -257,15 +403,28 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
must be large enough to hold output of either MoE gemm.
|
||||
- workspace2 (torch.Tensor): A scratch tensor used for the activation
|
||||
function.
|
||||
- expert_num_tokens: An optional tensor containing the number of tokens
|
||||
assigned to each expert when using batched experts format input.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The unweighted, unreduced output tensor
|
||||
- expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
|
||||
ExpertTokensMetadata object containing gpu/cpu tensors
|
||||
as big as the number of local experts with the information about the
|
||||
number of tokens assigned to each local expert.
|
||||
- apply_router_weight_on_input: True if router weights are already
|
||||
applied on the input. This is relevant if the implementation
|
||||
chooses to do weight application.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||
end: int) -> Optional[torch.Tensor]:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
return scales
|
||||
else:
|
||||
return scales[start:end]
|
||||
return None
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
This class combines a FusedMoEPrepareAndFinalize instance and
|
||||
@@ -287,46 +446,56 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
assert prepare_finalize.activation_format == \
|
||||
fused_experts.activation_formats[0], (
|
||||
f"{prepare_finalize.__class__.__name__}."
|
||||
f"{prepare_finalize.activation_format} == "
|
||||
f"{fused_experts.__class__.__name__}."
|
||||
f"{fused_experts.activation_formats[0]}")
|
||||
|
||||
def _do_fused_experts(
|
||||
self,
|
||||
a1: torch.Tensor, # input to forward fn
|
||||
a1q: torch.Tensor, # output of prepare fn
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
|
||||
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
activation: str, global_num_experts: int, local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
|
||||
|
||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
# Use a1 here to decipher the correct workspace datatype
|
||||
workspace13_shape, workspace2_shape, workspace_dtype = (
|
||||
self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
|
||||
global_num_experts))
|
||||
(workspace13_shape, workspace2_shape, fused_out_shape,
|
||||
workspace_dtype) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the time
|
||||
# we need cache3, we're done with cache1
|
||||
workspace13 = torch.zeros(workspace13_shape,
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = torch.empty(prod(workspace13_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = torch.zeros(workspace2_shape,
|
||||
workspace2 = torch.empty(prod(workspace2_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
fused_out = self.fused_experts.apply(
|
||||
assert fused_out is None or fused_out.shape == fused_out_shape, (
|
||||
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
|
||||
if fused_out is None:
|
||||
# reuse workspace13 for the output
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
|
||||
self.fused_experts.apply(
|
||||
fused_out,
|
||||
a1q,
|
||||
w1,
|
||||
w2,
|
||||
topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
@@ -338,8 +507,162 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
)
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=extra_expert_args)
|
||||
|
||||
return fused_out
|
||||
|
||||
def _maybe_chunk_fused_experts(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1q: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]],
|
||||
) -> torch.Tensor:
|
||||
|
||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_chunks = cdiv(M, CHUNK_SIZE)
|
||||
|
||||
if not self.fused_experts.supports_chunking() or num_chunks == 1:
|
||||
return self._do_fused_experts(
|
||||
fused_out=None,
|
||||
a1=a1,
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=extra_expert_args)
|
||||
|
||||
# Chunking required case
|
||||
assert num_chunks > 1
|
||||
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
fused_out = torch.empty(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M)
|
||||
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
|
||||
_chunk_scales(a2_scale, s,
|
||||
e), topk_ids[s:e], topk_weights[s:e])
|
||||
|
||||
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
|
||||
assert fused_out.size(0) % M == 0, (
|
||||
f"fused_out shape {fused_out.shape} vs M {M}")
|
||||
factor = fused_out.size(0) // M
|
||||
out_chunk_size = CHUNK_SIZE * factor
|
||||
s = chunk_idx * out_chunk_size
|
||||
e = min(s + out_chunk_size, fused_out.size(0))
|
||||
return fused_out[s:e]
|
||||
|
||||
def slice_expert_tokens_metadata(
|
||||
full_expert_tokens_meta: ExpertTokensMetadata,
|
||||
chunk_topk_ids: torch.Tensor, local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
|
||||
# The existing expert_num_tokens is for the entire a1q
|
||||
# input. Chunking forces recomputation of the number
|
||||
# of tokens assigned to each expert.
|
||||
c_expert_num_tokens = count_expert_num_tokens(
|
||||
chunk_topk_ids, local_num_experts, expert_map)
|
||||
|
||||
c_expert_num_tokens_cpu = None
|
||||
need_expert_num_tokens_cpu = (
|
||||
full_expert_tokens_meta.expert_num_tokens_cpu is not None)
|
||||
if need_expert_num_tokens_cpu:
|
||||
# This is blocking as some implementations need the count
|
||||
# on the CPU to determine appropriate input/out fused-moe
|
||||
# buffers
|
||||
c_expert_num_tokens_cpu = c_expert_num_tokens.to(
|
||||
"cpu", non_blocking=False)
|
||||
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=c_expert_num_tokens,
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
|
||||
|
||||
m = None
|
||||
if extra_expert_args is not None and 'm' in extra_expert_args:
|
||||
m = extra_expert_args.get('m')
|
||||
|
||||
if extra_expert_args is not None:
|
||||
chunked_extra_expert_args = extra_expert_args
|
||||
else:
|
||||
chunked_extra_expert_args = {}
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
|
||||
slice_input_tensors(chunk_idx))
|
||||
|
||||
c_expert_tokens_meta = None
|
||||
if expert_tokens_meta is not None:
|
||||
c_expert_tokens_meta = slice_expert_tokens_metadata(
|
||||
expert_tokens_meta, c_topk_ids, local_num_experts,
|
||||
expert_map)
|
||||
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M)
|
||||
|
||||
if m is not None:
|
||||
chunked_extra_expert_args['m'] = e - s
|
||||
self._do_fused_experts(
|
||||
fused_out=slice_output_tensor(chunk_idx),
|
||||
a1=a1,
|
||||
a1q=c_a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=c_topk_weights,
|
||||
topk_ids=c_topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=c_a1q_scale,
|
||||
a2_scale=c_a2_scale,
|
||||
expert_tokens_meta=c_expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=chunked_extra_expert_args)
|
||||
|
||||
return fused_out
|
||||
|
||||
@@ -361,6 +684,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
extra_expert_args: Optional[dict] = None,
|
||||
extra_prepare_args: Optional[dict] = None,
|
||||
extra_finalize_args: Optional[dict] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||
@@ -393,6 +719,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||
applied directly on the inputs. This is only applicable when topk is
|
||||
1.
|
||||
- extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
|
||||
fused_experts.apply.
|
||||
- extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
|
||||
to prepare.
|
||||
- extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
|
||||
to finalize.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
@@ -401,19 +733,31 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1 = hidden_states
|
||||
output = a1 if inplace else torch.zeros_like(a1)
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = w1.size(0)
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1, a1_scale, a2_scale, topk_weights, topk_ids,
|
||||
global_num_experts, expert_map, apply_router_weight_on_input)
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
extra_prepare_args,
|
||||
)
|
||||
|
||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||
topk_weights = (topk_weights if _expert_topk_weights is None else
|
||||
_expert_topk_weights)
|
||||
|
||||
fused_out = None
|
||||
|
||||
if a1q.numel() == 0:
|
||||
# This happens when none of the tokens from the all2all reach this
|
||||
# EP rank. Also, note that this is only relevant for CUDAGraph
|
||||
@@ -423,24 +767,31 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
# and can never run into the tensor.numel() == 0 case.
|
||||
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
|
||||
else:
|
||||
fused_out = self._do_fused_experts(
|
||||
fused_out = self._maybe_chunk_fused_experts(
|
||||
a1=a1,
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale)
|
||||
a2_scale=a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=extra_expert_args)
|
||||
|
||||
self.prepare_finalize.finalize(output, fused_out, topk_weights,
|
||||
topk_ids, apply_router_weight_on_input)
|
||||
self.prepare_finalize.finalize(
|
||||
output, fused_out, topk_weights, topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
extra_finalize_args)
|
||||
|
||||
return output
|
||||
|
||||
146
vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
Normal file
146
vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
|
||||
|
||||
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
Useful in the case when some FusedMoEPermuteExpertsUnpermute
|
||||
implementation does not perform weight application and reduction
|
||||
but cannot address the needs of all the compatible PrepareAndFinalize
|
||||
implementations.
|
||||
For example, BatchedTritonExperts is compatible with both
|
||||
PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize
|
||||
does the weight-application + reduction as part of the pplx combine kernel.
|
||||
But the BatchedPrepareAndFinalize needs an implementation. To facilitate
|
||||
this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
|
||||
so the PrepareAndFinalize implementations could choose how to
|
||||
weight + reduce.
|
||||
"""
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TopKWeightAndReduceDelegate)
|
||||
|
||||
def apply(self, output: Optional[torch.Tensor],
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
raise RuntimeError("The caller is expected to choose an appropriate "
|
||||
"TopKWeightAndReduce implementation.")
|
||||
|
||||
|
||||
class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
The fused_experts outputs have already been weight applied and reduced.
|
||||
This implementation is a no-op.
|
||||
"""
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TopKWeightAndReduceNoOP)
|
||||
|
||||
def apply(self, output: Optional[torch.Tensor],
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
# Weight application and reduction operations are already done.
|
||||
if output is None:
|
||||
return fused_expert_output
|
||||
|
||||
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
|
||||
# tensor.
|
||||
assert output.size() == fused_expert_output.size(), (
|
||||
"output shape is expected to match the fused_expert_output shape. "
|
||||
f"But got output={output.size()}, "
|
||||
f"used_expert_output={fused_expert_output.size()}")
|
||||
output.copy_(fused_expert_output, non_blocking=True)
|
||||
return output
|
||||
|
||||
|
||||
class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
TopKWeightAndReduce implementation for a fused_experts output
|
||||
of shape (m, topk, K)
|
||||
"""
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TopKWeightAndReduceContiguous)
|
||||
|
||||
def apply(self, output: Optional[torch.Tensor],
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
|
||||
m, num_topk = topk_ids.size()
|
||||
k = fused_expert_output.size(-1)
|
||||
if fused_expert_output.ndim == 2:
|
||||
fused_expert_output = fused_expert_output.view(m, num_topk, k)
|
||||
|
||||
assert fused_expert_output.size() == (m, num_topk, k), (
|
||||
f"Expected fused_expert_output size {(m, num_topk, k)}. But got "
|
||||
f"{fused_expert_output.size()}")
|
||||
|
||||
if not apply_router_weight_on_input:
|
||||
fused_expert_output.mul_(topk_weights.view(m, -1, 1))
|
||||
|
||||
if output is None:
|
||||
output = torch.empty((m, k),
|
||||
device=fused_expert_output.device,
|
||||
dtype=fused_expert_output.dtype)
|
||||
assert output.size() == (m, k), (
|
||||
f"Expected output size {(m, k)}. But got {output.size()}")
|
||||
|
||||
ops.moe_sum(fused_expert_output, output)
|
||||
return output
|
||||
|
||||
|
||||
class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
TopKWeightAndReduce implementation for a fused_experts output
|
||||
of shape (num_experts, batch_size, K)
|
||||
"""
|
||||
|
||||
def __init__(self, rank: int):
|
||||
self.rank = rank
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, TopKWeightAndReduceNaiveBatched)
|
||||
and (other.rank == self.rank))
|
||||
|
||||
def apply(self, output: Optional[torch.Tensor],
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
assert fused_expert_output.ndim == 3
|
||||
num_tokens = topk_ids.size(0)
|
||||
num_local_experts = fused_expert_output.size(0)
|
||||
K = fused_expert_output.size(-1)
|
||||
|
||||
if output is None:
|
||||
output = torch.zeros((num_tokens, K),
|
||||
device=fused_expert_output.device,
|
||||
dtype=fused_expert_output.dtype)
|
||||
else:
|
||||
output.fill_(0)
|
||||
|
||||
assert output.size() == (num_tokens, K), (
|
||||
f"Expected output size {(num_tokens, K)}, but got {output.size()}")
|
||||
|
||||
first_expert = num_local_experts * self.rank
|
||||
last_expert = first_expert + num_local_experts
|
||||
|
||||
for expert_id in range(first_expert, last_expert):
|
||||
matching_tokens = topk_ids == expert_id
|
||||
topks = torch.any(matching_tokens, dim=1).flatten()
|
||||
rows = torch.count_nonzero(topks)
|
||||
rhs = fused_expert_output[expert_id - first_expert, :rows, :]
|
||||
if not apply_router_weight_on_input:
|
||||
rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
|
||||
output[topks] = output[topks] + rhs
|
||||
|
||||
return output
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from math import prod
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,7 +10,83 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8, per_token_quant_int8)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
quant_dequant_mxfp4)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv
|
||||
# from vllm.utils.flashinfer import fp4_quantize
|
||||
|
||||
|
||||
@triton.jit
def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts,
                             topk_numel, expert_map,
                             HAS_EXPERT_MAP: tl.constexpr,
                             BLOCK_SIZE: tl.constexpr):
    # Triton kernel: one program instance per expert; each instance scans the
    # whole flattened topk_ids buffer and counts how many entries match its
    # expert id, writing the total to expert_num_tokens_ptr[curr_expert].
    curr_expert = tl.program_id(0)

    offsets = tl.arange(0, BLOCK_SIZE)
    topk_ids_ptrs = topk_ids_ptr + offsets

    # Per-lane partial counts, reduced with tl.sum() after the scan.
    acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)
    for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)):
        mask = offsets < (topk_numel - x * BLOCK_SIZE)
        # Out-of-range lanes load -1, which never equals a valid expert id,
        # so they contribute nothing to the count.
        expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1)
        if HAS_EXPERT_MAP:
            # Map global expert ids to local (per-rank) ids. Ids that are
            # already invalid (< 0) stay -1 and are never counted.
            expert_map_ptrs = expert_map + expert_ids
            expert_map_mask = expert_ids >= 0
            expert_ids = tl.load(expert_map_ptrs,
                                 mask=expert_map_mask,
                                 other=-1)

        has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0)
        acc = acc + has_curr_expert
        topk_ids_ptrs += BLOCK_SIZE

    # Guard against over-provisioned grids before storing the reduction.
    if curr_expert < num_experts:
        tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))
|
||||
|
||||
|
||||
def count_expert_num_tokens(
        topk_ids: torch.Tensor, num_local_experts: int,
        expert_map: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Count the number of tokens assigned to each expert.

    Parameters:
    - topk_ids (torch.Tensor): Tensor mapping each token to its
    list of experts.
    - num_local_experts (int): Number of experts in this rank.
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
    from the global expert space to the local expert space of the expert
    parallel shard.

    Returns:
    A tensor of size num_local_experts, where tensor[i] holds the number
    of tokens assigned to the ith expert.
    """
    # The triton kernel fills masked/unmapped slots with -1; an unsigned
    # dtype could not represent that sentinel.
    assert topk_ids.dtype.is_signed, (
        "The kernel uses -1 to represent invalid topk_ids")
    expert_num_tokens = torch.empty((num_local_experts),
                                    device=topk_ids.device,
                                    dtype=torch.int32)

    # One program per local expert; each scans all of topk_ids.
    grid = num_local_experts
    # Triton requires a power-of-two block size; cap at 1024 lanes.
    BLOCK_SIZE = min(topk_ids.numel(), 1024)
    BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE)

    _count_expert_num_tokens[(grid, )](
        topk_ids,
        expert_num_tokens,
        num_local_experts,
        topk_ids.numel(),
        expert_map,
        HAS_EXPERT_MAP=expert_map is not None,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return expert_num_tokens
|
||||
|
||||
|
||||
def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
|
||||
@@ -23,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
|
||||
return x.flatten()[:prod(v)].view(*v)
|
||||
|
||||
|
||||
# def _fp4_quantize(
|
||||
# A: torch.Tensor,
|
||||
# A_scale: Optional[torch.Tensor],
|
||||
# is_sf_swizzled_layout: bool,
|
||||
# ) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# return fp4_quantize(A,
|
||||
# A_scale,
|
||||
# is_sf_swizzled_layout=is_sf_swizzled_layout)
|
||||
|
||||
|
||||
def _fp8_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
@@ -34,9 +120,12 @@ def _fp8_quantize(
|
||||
is provided, the output will be blocked.
|
||||
"""
|
||||
if block_shape is None:
|
||||
# TODO(luka): use QuantFP8 custom op
|
||||
# https://github.com/vllm-project/vllm/issues/20711
|
||||
A, A_scale = ops.scaled_fp8_quant(
|
||||
A, A_scale, use_per_token_if_dynamic=per_act_token)
|
||||
else:
|
||||
assert not per_act_token
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
@@ -62,9 +151,9 @@ def _int8_quantize(
|
||||
if block_shape is None:
|
||||
assert per_act_token, \
|
||||
"int8 quantization only supports block or channel-wise"
|
||||
# A, A_scale = per_token_quant_int8(A)
|
||||
A, A_scale, _ = ops.scaled_int8_quant(A, A_scale)
|
||||
A, A_scale = per_token_quant_int8(A)
|
||||
else:
|
||||
assert not per_act_token
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_int8(A, block_k)
|
||||
@@ -73,19 +162,40 @@ def _int8_quantize(
|
||||
return A, A_scale
|
||||
|
||||
|
||||
def _mxfp4_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, None]:
    """mxfp4-quantize activations A; no separate scale tensor is produced.

    On platforms without native MX support the quantization is emulated by a
    quantize->dequantize round trip, so A keeps its original dtype.
    Block-wise quantization is not supported here (block_shape must be None).
    """
    assert block_shape is None
    if current_platform.supports_mx():
        # Native MX path is not implemented yet.
        raise NotImplementedError()

    # Emulation path: round-trip through mxfp4 to mimic its precision loss.
    return quant_dequant_mxfp4(A), None
|
||||
|
||||
|
||||
def moe_kernel_quantize_input(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
qtype: Optional[torch.dtype],
|
||||
per_channel_quant: bool,
|
||||
quant_dtype: Union[None, torch.dtype, str],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
is_fp4_scale_swizzled: bool = True,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if qtype == torch.float8_e4m3fn:
|
||||
return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
|
||||
elif qtype == torch.int8:
|
||||
return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == torch.int8:
|
||||
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == torch.uint8: # nvfp4
|
||||
return _fp4_quantize(A,
|
||||
A_scale,
|
||||
is_sf_swizzled_layout=is_fp4_scale_swizzled)
|
||||
elif quant_dtype == "mxfp4":
|
||||
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
else:
|
||||
assert A_scale is None
|
||||
return A, A_scale
|
||||
|
||||
|
||||
@@ -97,3 +207,62 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
||||
else:
|
||||
return m[idx, ...]
|
||||
|
||||
|
||||
def normalize_scales_shape(
        scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Reshape `scales` to a canonical 2D layout.

    A single-element tensor becomes (1, 1); anything else is flattened to
    (-1, last_dim). None passes through unchanged.
    """
    if scales is None:
        return None
    if scales.numel() == 1:
        return scales.view(1, 1)
    return scales.view(-1, scales.size(-1))
|
||||
|
||||
|
||||
def normalize_batched_scales_shape(
    scales: Optional[torch.Tensor],
    num_experts: int,
) -> Optional[torch.Tensor]:
    """Reshape `scales` to a canonical batched (num_experts, M, K) layout.

    A single-element tensor is broadcast to (num_experts, 1, 1); any other
    sub-3D tensor is viewed as (num_experts, -1, last_dim). Tensors that are
    already 3D (or None) pass through unchanged.
    """
    if scales is None or scales.ndim >= 3:
        return scales

    if scales.numel() == 1:
        # Replicate the single scale once per expert.
        broadcast = torch.repeat_interleave(scales.view(1),
                                            num_experts,
                                            dim=0)
        return broadcast.view(num_experts, 1, 1)

    return scales.view(num_experts, -1, scales.size(-1))
|
||||
|
||||
|
||||
def _validate_scale_shape(
|
||||
a: torch.Tensor,
|
||||
a_scale: Optional[torch.Tensor],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
) -> None:
|
||||
if a_scale is None:
|
||||
return
|
||||
|
||||
if not per_act_token_quant and block_shape is None:
|
||||
assert a_scale.numel() == 1, f"{a_scale.shape}"
|
||||
elif per_act_token_quant:
|
||||
assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, (
|
||||
f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1")
|
||||
else:
|
||||
assert block_shape is not None
|
||||
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
|
||||
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
|
||||
|
||||
|
||||
def extract_required_args(
    extra_args: Optional[dict[str, Any]],
    required_keys: list[str],
) -> tuple[Any, ...]:
    """Pull `required_keys` out of `extra_args`, in order, as a tuple.

    Raises:
        ValueError: if `extra_args` is None, or if any required key
            is absent from it.
    """
    if extra_args is None:
        raise ValueError("`extra_args` must be provided.")

    missing_keys = [name for name in required_keys if name not in extra_args]
    if missing_keys:
        raise ValueError(f"Missing keys in `extra_args`: {missing_keys}")

    values = []
    for name in required_keys:
        values.append(extra_args[name])
    return tuple(values)
|
||||
|
||||
Reference in New Issue
Block a user