Sync from v0.13
This commit is contained in:
0
tests/kernels/moe/__init__.py
Normal file
0
tests/kernels/moe/__init__.py
Normal file
0
tests/kernels/moe/modular_kernel_tools/__init__.py
Normal file
0
tests/kernels/moe/modular_kernel_tools/__init__.py
Normal file
164
tests/kernels/moe/modular_kernel_tools/cli_args.py
Normal file
164
tests/kernels/moe/modular_kernel_tools/cli_args.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
|
||||
from .common import Config
|
||||
from .mk_objects import (
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES,
|
||||
MK_FUSED_EXPERT_TYPES,
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
|
||||
)
|
||||
|
||||
|
||||
def make_config_arg_parser(description: str):
    """Build the argparse parser used to describe a modular-kernel test Config.

    The ``--pf-type`` / ``--experts-type`` options take class *names* on the
    command line and convert them into the actual classes via custom
    ``type=`` callables.

    Args:
        description: Text shown in the parser's ``--help`` output.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """

    def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
        # Resolve a class name to the matching PrepareFinalize class.
        found = next(
            (pf for pf in MK_ALL_PREPARE_FINALIZE_TYPES if pf.__name__ == s), None
        )
        if found is None:
            raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}")
        return found

    def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
        # Resolve a class name to the matching FusedExperts class.
        found = next(
            (fe for fe in MK_FUSED_EXPERT_TYPES if fe.__name__ == s), None
        )
        if found is None:
            raise ValueError(f"Cannot find a FusedExperts type that matches {s}")
        return found

    def to_quant_torch_dtype(s: str) -> torch.dtype:
        # Only fp8-e4m3 is accepted from the command line.
        if s == "torch.float8_e4m3fn":
            return torch.float8_e4m3fn
        raise ValueError(f"Unsupported quant type {s}")

    parser = argparse.ArgumentParser(description=description)

    parser.add_argument(
        "--world-size",
        type=int,
        default=2,
        help="Number of ranks that participate in all2all",
    )
    parser.add_argument(
        "--pf-type",
        type=to_pf_class_type,
        required=True,
        help=(
            "Choose a PrepareFinalize Type : "
            f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"
        ),
    )
    parser.add_argument(
        "--experts-type",
        type=to_experts_class_type,
        required=True,
        help=(
            f"Choose a FusedExpert type : {[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"
        ),
    )

    # Problem-shape arguments.
    parser.add_argument(
        "-m", nargs="+", type=int, default=[64], help="num tokens per rank"
    )
    parser.add_argument("-k", type=int, default=7168, help="hidden-size")
    parser.add_argument(
        "-n",
        type=int,
        default=1024,
        help="N dimension of the first fused-moe matmul",
    )
    parser.add_argument(
        "--num-experts", type=int, default=32, help="Global num experts"
    )
    parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
    parser.add_argument(
        "--fused-moe-chunk-size",
        type=int,
        help="Fused moe chunk size used for the non-batched fused experts impl.",
    )

    # Quant args
    parser.add_argument(
        "--quant-dtype", type=to_quant_torch_dtype, help="Quant datatype"
    )
    parser.add_argument(
        "--per-token-quantized-activations",
        action="store_true",
        help=("The input activations must be per-token quantized"),
    )
    parser.add_argument(
        "--per-channel-quantized-weights",
        action="store_true",
        help="The weights must be per-channel quantized.",
    )
    parser.add_argument(
        "--block-shape", nargs="+", type=int, help="Quantization block shape"
    )

    # Torch trace profile generation args
    parser.add_argument(
        "--torch-trace-dir-path",
        type=str,
        default=None,
        help="Get torch trace for single execution",
    )

    return parser
|
||||
|
||||
|
||||
def _validate_args(args: argparse.Namespace):
    """Sanity-check parsed CLI arguments before building a Config.

    Raises:
        AssertionError: if any argument combination is unsupported.
    """
    if args.quant_dtype is not None:
        # CLI only accepts fp8-e4m3 (see to_quant_torch_dtype).
        assert args.quant_dtype == torch.float8_e4m3fn
        if args.block_shape is not None:
            assert len(args.block_shape) == 2, (
                f"block shape must have 2 elements. got {args.block_shape}"
            )

    # BUGFIX: the single-GPU restriction concerns the chosen PrepareFinalize
    # type (--pf-type), not the experts type. The previous check compared an
    # experts class against prepare-finalize classes, which can never match,
    # making the world-size assertion dead code.
    if args.pf_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
        assert args.world_size == 1, "Single GPU objects need world size set to 1"

    if args.torch_trace_dir_path is not None:
        from pathlib import Path

        # Tracing writes into an existing directory; fail fast otherwise.
        assert Path(args.torch_trace_dir_path).is_dir(), (
            f"Please create {args.torch_trace_dir_path}"
        )
|
||||
|
||||
|
||||
def make_config(args: argparse.Namespace) -> Config:
    """Translate validated CLI arguments into a test ``Config``.

    Args:
        args: Namespace produced by ``make_config_arg_parser``.

    Returns:
        A ``Config`` describing the requested test case.
    """
    _validate_args(args)

    # Only build a quant config when quantization was actually requested.
    quant_config = (
        FusedMoEQuantConfig(
            quant_dtype=args.quant_dtype,
            per_act_token_quant=args.per_token_quantized_activations,
            per_out_ch_quant=args.per_channel_quantized_weights,
            block_shape=args.block_shape,
        )
        if args.quant_dtype is not None
        else None
    )

    return Config(
        Ms=args.m,
        K=args.k,
        N=args.n,
        E=args.num_experts,
        topks=args.topk,
        dtype=torch.bfloat16,  # hard-code
        quant_config=quant_config,
        prepare_finalize_type=args.pf_type,
        fused_experts_type=args.experts_type,
        fused_moe_chunk_size=args.fused_moe_chunk_size,
        world_size=args.world_size,
        torch_trace_dir_path=args.torch_trace_dir_path,
    )
|
||||
668
tests/kernels/moe/modular_kernel_tools/common.py
Normal file
668
tests/kernels/moe/modular_kernel_tools/common.py
Normal file
@@ -0,0 +1,668 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_pcp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
|
||||
from .mk_objects import (
|
||||
TestMoEQuantConfig,
|
||||
expert_info,
|
||||
make_fused_experts,
|
||||
make_prepare_finalize,
|
||||
prepare_finalize_info,
|
||||
)
|
||||
from .parallel_utils import ProcessGroupInfo
|
||||
|
||||
|
||||
def _describe_tensor(t: torch.Tensor | None, name: str) -> str:
|
||||
if t is None:
|
||||
return f"{name} : None"
|
||||
else:
|
||||
return f"{name} : {t.shape} {t.dtype} {t.device}"
|
||||
|
||||
|
||||
@dataclass
class Config:
    """Describes one modular-kernel MoE test case.

    ``Ms`` / ``topks`` hold lists while sweeping over sub-cases and a single
    int once a concrete sub-case has been selected (see the ``M`` / ``topk``
    properties, which assert the int form).
    """

    Ms: list[int] | int
    K: int
    N: int
    E: int
    topks: list[int] | int
    dtype: torch.dtype
    quant_config: TestMoEQuantConfig | None

    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute

    fused_moe_chunk_size: int | None
    world_size: int

    torch_trace_dir_path: str | None = None

    def __post_init__(self):
        # Normalize "no quantization" into an explicit all-disabled config so
        # the quant_* properties never have to special-case None.
        if self.quant_config is None:
            self.quant_config = TestMoEQuantConfig(None, False, False, None)

    def describe(self) -> str:
        """Render the configuration as a human-readable multi-line string."""
        lines = [
            "== Config:",
            f" world_size={self.world_size}",
            f" PF={self.prepare_finalize_type.__name__}",
            f" FE={self.fused_experts_type.__name__}",
            f" E={self.E}",
            f" Ms={self.Ms}",
            f" N={self.N}",
            f" K={self.K}",
            f" topk={self.topks}",
            f" dtype={self.dtype}",
            f" fused_moe_chunk_size={self.fused_moe_chunk_size}",
            " Quant:",
        ]
        if self.quant_config is not None:
            lines += [
                f" q_dtype={self.quant_dtype}",
                f" q_block_shape={self.quant_block_shape}",
                f" q_per_out_ch_quant={self.is_per_out_ch_quant}",
                f" q_per_act_token={self.is_per_act_token_quant}",
            ]
        else:
            lines.append(" quant=None")
        return "\n".join(lines) + "\n"

    @property
    def M(self) -> int:
        # Only valid once a concrete sub-case has been selected.
        assert isinstance(self.Ms, int)
        return self.Ms

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        assert self.quant_config is not None
        return self.quant_config.quant_dtype

    @property
    def is_per_act_token_quant(self) -> bool:
        assert self.quant_config is not None
        return self.quant_config.per_act_token_quant

    @property
    def is_per_tensor_act_quant(self) -> bool:
        # Per-tensor is the fallback when neither per-token nor block quant.
        return not self.is_per_act_token_quant and self.quant_block_shape is None

    @property
    def is_per_out_ch_quant(self) -> bool:
        assert self.quant_config is not None
        return self.quant_config.per_out_ch_quant

    @property
    def quant_block_shape(self) -> list[int] | None:
        assert self.quant_config is not None
        return self.quant_config.block_shape

    @property
    def topk(self) -> int:
        # Only valid once a concrete sub-case has been selected.
        assert isinstance(self.topks, int)
        return self.topks

    @property
    def num_local_experts(self) -> int:
        # Experts are split evenly across ranks.
        return self.E // self.world_size

    def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
        """
        make env data for vllm launch.
        """
        vllm_config = VllmConfig()
        vllm_config.parallel_config.data_parallel_size = self.world_size
        vllm_config.parallel_config.enable_expert_parallel = True

        env_dict = {
            "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
        }

        backend = self.all2all_backend()
        vllm_config.parallel_config.all2all_backend = backend
        if backend is not None:
            env_dict["VLLM_ALL2ALL_BACKEND"] = backend

        if self.fused_moe_chunk_size is not None:
            env_dict["VLLM_FUSED_MOE_CHUNK_SIZE"] = str(self.fused_moe_chunk_size)

        return vllm_config, env_dict

    def is_fp8_block_quantized(self):
        return (
            self.quant_dtype == torch.float8_e4m3fn
            and self.quant_block_shape is not None
        )

    def is_batched_prepare_finalize(self):
        fmt = prepare_finalize_info(self.prepare_finalize_type).activation_format
        return fmt == mk.FusedMoEActivationFormat.BatchedExperts

    def is_batched_fused_experts(self):
        fmt = expert_info(self.fused_experts_type).activation_format
        return fmt == mk.FusedMoEActivationFormat.BatchedExperts

    def is_standard_fused_experts(self):
        fmt = expert_info(self.fused_experts_type).activation_format
        return fmt == mk.FusedMoEActivationFormat.Standard

    def fe_supported_types(self):
        return expert_info(self.fused_experts_type).supported_dtypes

    def pf_supported_types(self):
        return prepare_finalize_info(self.prepare_finalize_type).supported_dtypes

    def is_block_quant_supported(self):
        return expert_info(self.fused_experts_type).blocked_quantization_support

    def is_fe_supports_chunking(self):
        return expert_info(self.fused_experts_type).supports_chunking

    def supports_expert_map(self):
        return expert_info(self.fused_experts_type).supports_expert_map

    def supports_apply_weight_on_input(self):
        return prepare_finalize_info(
            self.prepare_finalize_type
        ).supports_apply_weight_on_input

    def needs_deep_gemm(self):
        return expert_info(self.fused_experts_type).needs_deep_gemm

    def needs_pplx(self):
        return prepare_finalize_info(self.prepare_finalize_type).backend == "pplx"

    def needs_deep_ep(self):
        backend = prepare_finalize_info(self.prepare_finalize_type).backend
        return backend in ("deepep_high_throughput", "deepep_low_latency")

    def all2all_backend(self):
        return prepare_finalize_info(self.prepare_finalize_type).backend

    def is_valid(self) -> tuple[bool, str | None]:
        """Return ``(ok, reason)``; ``reason`` is None when the config is runnable."""
        # Prepare-finalize and fused-experts must agree on activation format.
        if self.is_batched_prepare_finalize():
            if not self.is_batched_fused_experts():
                return False, "Mismatched format."
        elif not self.is_standard_fused_experts():
            return False, "Mismatched format."

        if self.fused_moe_chunk_size is not None and not self.is_fe_supports_chunking():
            return False, "Chunking not supported."

        # At most one activation-quantization scheme may be active.
        num_act_quant_schemes = (
            int(self.is_per_act_token_quant)
            + int(self.is_per_tensor_act_quant)
            + int(self.quant_block_shape is not None)
        )
        if num_act_quant_schemes > 1:
            # invalid quant config
            return False, f"Bad quant_config {self.quant_config}."

        # The working dtype (quantized or not) must be supported by both the
        # prepare-finalize and the fused-experts implementations.
        if self.quant_dtype is None:
            if (
                self.dtype not in self.pf_supported_types()
                or self.dtype not in self.fe_supported_types()
            ):
                return False, (
                    f"Unsupported type {self.dtype} not in "
                    f"{self.pf_supported_types()} and "
                    f"{self.fe_supported_types()}."
                )
        else:
            if (
                self.quant_dtype not in self.pf_supported_types()
                or self.quant_dtype not in self.fe_supported_types()
            ):
                return False, (
                    f"Unsupported quant type {self.quant_dtype} "
                    f"not in {self.pf_supported_types()} and "
                    f"{self.fe_supported_types()}."
                )

        # Block-quantization support.
        is_block_quantized = self.quant_block_shape is not None
        if is_block_quantized and self.quant_dtype is None:
            return False, "No block quantization support."
        if is_block_quantized and not self.is_block_quant_supported():
            return False, "Mismatched block quantization support."

        # deep_gemm only works with block-quantized
        if self.needs_deep_gemm() and not is_block_quantized:
            return False, "Needs DeepGEMM but not block quantized."

        # Check dependencies (turn into asserts?)
        if self.needs_deep_ep() and not has_deep_ep():
            return False, "Needs DeepEP, but DeepEP not available."
        if self.needs_deep_gemm() and not has_deep_gemm():
            return False, "Needs DeepGEMM, but DeepGEMM not available."
        if self.needs_pplx() and not has_pplx():  # noqa: SIM103
            return False, "Needs PPLX, but PPLX not available."

        return True, None
|
||||
|
||||
|
||||
@dataclass
class WeightTensors:
    """Expert weight tensors (possibly quantized) plus their scales.

    ``w1_gs`` / ``w2_gs`` are global scales — presumably only populated for
    nvfp4-style quantization (they are produced by ``make_test_weights``);
    verify against that helper.
    """

    w1: torch.Tensor
    w2: torch.Tensor
    w1_scale: torch.Tensor | None
    w2_scale: torch.Tensor | None
    w1_gs: torch.Tensor | None = None
    w2_gs: torch.Tensor | None = None

    def describe(self):
        """Render a human-readable summary of all weight tensors."""
        lines = [
            "== Weight Tensors: ",
            f" - {_describe_tensor(self.w1, 'w1')} ",
            f" - {_describe_tensor(self.w2, 'w2')} ",
            f" - {_describe_tensor(self.w1_scale, 'w1_scale')} ",
            f" - {_describe_tensor(self.w2_scale, 'w2_scale')} ",
            f" - {_describe_tensor(self.w1_gs, 'w1_gs')} ",
            f" - {_describe_tensor(self.w2_gs, 'w2_gs')} ",
        ]
        return "\n".join(lines) + "\n"

    def is_quantized(self) -> bool:
        # or w1_scale is not None?
        return self.w1.dtype in (torch.float8_e4m3fn, torch.uint8, torch.int8)

    def to_current_device(self):
        """Move every present tensor onto the current CUDA device (in place)."""
        dev = torch.cuda.current_device()
        self.w1 = self.w1.to(device=dev)
        self.w2 = self.w2.to(device=dev)

        if self.w1_scale is not None:
            self.w1_scale = self.w1_scale.to(device=dev)
        if self.w2_scale is not None:
            self.w2_scale = self.w2_scale.to(device=dev)

        if self.w1_gs is not None:
            self.w1_gs = self.w1_gs.to(device=dev)
        if self.w2_gs is not None:
            self.w2_gs = self.w2_gs.to(device=dev)

    def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors":
        """Return this rank's contiguous expert slice of every tensor."""
        start = rank * num_local_experts
        end = start + num_local_experts
        w1 = self.w1[start:end, :, :]
        w2 = self.w2[start:end, :, :]
        w1_scale = self.w1_scale[start:end, :, :] if self.w1_scale is not None else None
        w2_scale = self.w2_scale[start:end, :, :] if self.w2_scale is not None else None
        w1_gs = self.w1_gs[start:end] if self.w1_gs is not None else None
        w2_gs = self.w2_gs[start:end] if self.w2_gs is not None else None

        return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)

    @staticmethod
    def make(config: Config) -> "WeightTensors":
        """Construct test weights matching the given config's shapes/quant."""
        (_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights(
            e=config.E,
            n=config.N,
            k=config.K,
            in_dtype=config.dtype,
            quant_dtype=config.quant_dtype,
            block_shape=config.quant_block_shape,
            # or config.is_per_out_ch_quant
            # NOTE(review): passes is_per_act_token_quant for per_out_ch_quant,
            # as in the original (its own comment flags the oddity) — preserved.
            per_out_ch_quant=config.is_per_act_token_quant,
        )
        return WeightTensors(
            w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale, w1_gs=w1_gs, w2_gs=w2_gs
        )
|
||||
|
||||
|
||||
@dataclass
class RankTensors:
    """Per-rank input tensors for a single modular-kernel run."""

    hidden_states: torch.Tensor
    # Only populated for per-tensor activation quantization (see
    # make_hidden_states); None for per-token and block quant.
    hidden_states_scale: torch.Tensor | None

    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    expert_map: torch.Tensor | None

    def describe(self):
        """Render a human-readable summary of all rank tensors."""
        lines = [
            "== Rank Tensors: ",
            f" - {_describe_tensor(self.hidden_states, 'HS')} ",
            f" - {_describe_tensor(self.hidden_states_scale, 'HS_scale')} ",
            f" - {_describe_tensor(self.topk_weights, 'topk_weights')} ",
            f" - {_describe_tensor(self.topk_ids, 'topk_ids')} ",
            f" - {_describe_tensor(self.expert_map, 'expert_map')} ",
        ]
        return "\n".join(lines) + "\n"

    @staticmethod
    def make_hidden_states(
        config: Config,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Return hidden_states
        """
        m, k, dtype = (config.M, config.K, config.dtype)
        a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0

        if config.quant_dtype is None:
            return a, None

        # We dequant and use that as hidden_states so the tests are stable.
        # quantizing and dequantizing yield slightly different results
        # depending on the hardware. Here we, quantize and dequantize
        # first - so further quantize and dequantize will yield the same
        # values.
        if config.is_per_tensor_act_quant:
            a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False)
            return a_q.float().mul(a_scales).to(dtype), a_scales

        if config.is_per_act_token_quant:
            a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
            return a_q.float().mul(a_scales).to(dtype), None

        # Block quantization: dequantize block-by-block back to dtype.
        assert config.quant_block_shape is not None
        block_k = config.quant_block_shape[1]
        a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
        dequant = a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k)
        return dequant.to(dtype), None

    @staticmethod
    def make(config: Config, pgi: ProcessGroupInfo):
        """Build the per-rank inputs: activations, routing and expert map."""
        dtype = config.dtype
        topk, m, _ = (config.topk, config.M, config.K)
        hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config)

        num_local_experts, global_num_experts = (config.num_local_experts, config.E)
        score = torch.randn((m, global_num_experts), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)

        # distribute topk_ids evenly
        for mi in range(m):
            topk_ids[mi] = torch.randperm(config.E)[:topk]
        topk_ids = topk_ids.to(device=torch.cuda.current_device())

        expert_map = None
        if config.world_size > 1 and config.supports_expert_map():
            # -1 marks experts owned by other ranks; local experts get
            # their local indices [0, num_local_experts).
            expert_map = torch.full(
                (global_num_experts,), fill_value=-1, dtype=torch.int32
            )
            start = pgi.rank * num_local_experts
            end = start + num_local_experts
            expert_map[start:end] = torch.arange(num_local_experts)
            expert_map = expert_map.to(
                device=torch.cuda.current_device(), dtype=torch.int32
            )

        return RankTensors(
            hidden_states=hidden_states,
            hidden_states_scale=hidden_states_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
        )
|
||||
|
||||
|
||||
def reference_moe_impl(
    config: Config, weights: WeightTensors, rank_tensors: RankTensors
) -> torch.Tensor:
    """Compute a reference MoE output with ``torch_experts``.

    For nvfp4 configs, activations and weights are first dequantized to
    ``config.dtype`` and the unquantized reference path is used; otherwise the
    quantization parameters are forwarded to ``torch_experts`` directly.
    """
    if config.quant_dtype == "nvfp4":
        quant_blocksize = 16
        dtype = config.dtype

        w1_q = weights.w1
        w1_blockscale = weights.w1_scale
        w1_gs = weights.w1_gs

        w2_q = weights.w2
        w2_blockscale = weights.w2_scale
        w2_gs = weights.w2_gs

        # Global activation scale derived from the max |activation|.
        a_global_scale = (
            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
            / torch.amax(rank_tensors.hidden_states.flatten(), dim=-1)
        ).to(torch.float32)

        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        assert w1_blockscale.shape[1] % 128 == 0
        assert w1_blockscale.shape[2] % 4 == 0
        assert w2_blockscale.shape[1] % 128 == 0
        assert w2_blockscale.shape[2] % 4 == 0

        # Round-trip the activations through fp4 so the reference sees the
        # same quantization error as the kernel under test.
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
            rank_tensors.hidden_states, a_global_scale
        )
        a = dequantize_nvfp4_to_dtype(
            a_fp4,
            a_scale_interleaved,
            a_global_scale,
            dtype=dtype,
            device=a_fp4.device,
            block_size=quant_blocksize,
        )

        e = w1_q.shape[0]
        n = w1_q.shape[1] // 2
        k = w2_q.shape[1]

        # Dequantize every expert's weights back to the working dtype.
        w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype)
        w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
        for eid in range(0, e):
            w1[eid] = dequantize_nvfp4_to_dtype(
                w1_q[eid],
                w1_blockscale[eid],
                w1_gs[eid],
                dtype=dtype,
                device=w1_q.device,
                block_size=quant_blocksize,
            )
            w2[eid] = dequantize_nvfp4_to_dtype(
                w2_q[eid],
                w2_blockscale[eid],
                w2_gs[eid],
                dtype=dtype,
                device=w2_q.device,
                block_size=quant_blocksize,
            )

        # Everything is dequantized — run the reference unquantized.
        a_scale = None
        w1_scale = None
        w2_scale = None
        quant_dtype = None
        per_act_token_quant = False
        block_shape = None
    else:
        a = rank_tensors.hidden_states
        a_scale = rank_tensors.hidden_states_scale
        w1 = weights.w1
        w1_scale = weights.w1_scale
        w2 = weights.w2
        w2_scale = weights.w2_scale
        quant_dtype = config.quant_dtype
        per_act_token_quant = config.is_per_act_token_quant
        block_shape = config.quant_block_shape

    return torch_experts(
        a=a,
        w1=w1,
        w2=w2,
        topk_weight=rank_tensors.topk_weights,
        topk_ids=rank_tensors.topk_ids,
        global_num_experts=config.E,
        expert_map=None,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a_scale,
        quant_dtype=quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        apply_router_weights_on_input=config.topk == 1
        and config.supports_apply_weight_on_input(),
    )
|
||||
|
||||
|
||||
def _make_gscale(num_experts: int) -> torch.Tensor:
    """Return per-expert global scales of 1.0 on the current CUDA device."""
    return torch.ones(
        (num_experts,), device=torch.cuda.current_device(), dtype=torch.float32
    )
|
||||
|
||||
|
||||
def make_modular_kernel(
    config: Config,
    vllm_config: VllmConfig,
    quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel:
    """Assemble a FusedMoEModularKernel from the configured component types.

    Args:
        config: Test-case description (shapes, component types).
        vllm_config: vLLM configuration providing the parallel settings.
        quant_config: Quantization parameters forwarded to both components.
    """

    def next_power_of_2(x):
        import math

        return 1 if x == 0 else 2 ** math.ceil(math.log2(x))

    # make moe config
    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
        tp_size_=get_tensor_model_parallel_world_size(),
        pcp_size_=get_pcp_group().world_size,
        dp_size_=get_dp_group().world_size,
        vllm_parallel_config=vllm_config.parallel_config,
    )

    moe = FusedMoEConfig(
        num_experts=config.E,
        experts_per_token=config.topk,
        hidden_dim=config.K,
        num_local_experts=config.num_local_experts,
        moe_parallel_config=moe_parallel_config,
        in_dtype=config.dtype,
        max_num_tokens=next_power_of_2(config.M),
    )

    # make modular kernel
    prepare_finalize = make_prepare_finalize(
        config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
    )
    fused_experts = make_fused_experts(
        config.fused_experts_type,
        moe,
        quant_config,
        prepare_finalize.num_dispatchers(),
        config.N,
    )

    return mk.FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=fused_experts,
    )
|
||||
|
||||
|
||||
def run_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> torch.Tensor:
    """Run the configured modular kernel on this rank and return its output."""
    # A concrete sub-case must have been selected (ints, not sweep lists).
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # weights for rank
    rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    # Global activation scales are only meaningful for nvfp4.
    gscale = (
        _make_gscale(config.num_local_experts)
        if config.quant_dtype == "nvfp4"
        else None
    )

    quant_config = FusedMoEQuantConfig.make(
        config.quant_dtype,
        w1_scale=rank_weights.w1_scale,
        w2_scale=rank_weights.w2_scale,
        a1_scale=rank_tensors.hidden_states_scale,
        g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None,
        g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None,
        a1_gscale=gscale,
        a2_gscale=gscale,
        block_shape=config.quant_block_shape,
        per_act_token_quant=config.is_per_act_token_quant,
        per_out_ch_quant=config.is_per_out_ch_quant,
    )

    # Renamed from `mk` to avoid shadowing the module alias `mk`.
    kernel = make_modular_kernel(config, vllm_config, quant_config)

    # impls might update the tensor in place
    hidden_states = rank_tensors.hidden_states.clone()

    topk_ids = rank_tensors.topk_ids.to(kernel.prepare_finalize.topk_indices_dtype())

    mk_kwargs = {
        "hidden_states": hidden_states,
        "w1": rank_weights.w1,
        "w2": rank_weights.w2,
        "topk_weights": rank_tensors.topk_weights,
        "topk_ids": topk_ids,
        "expert_map": rank_tensors.expert_map,
        "global_num_experts": config.E,
        "apply_router_weight_on_input": config.topk == 1
        and config.supports_apply_weight_on_input(),
    }

    num_tokens = rank_tensors.hidden_states.shape[0]
    num_tokens_across_dp = torch.tensor(
        [num_tokens] * config.world_size, device="cuda", dtype=torch.int
    )

    with set_forward_context(
        None,
        vllm_config,
        num_tokens=num_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
    ):
        out = kernel.forward(**mk_kwargs)

    return out
|
||||
196
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
Normal file
196
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .common import (
|
||||
Config,
|
||||
RankTensors,
|
||||
WeightTensors,
|
||||
reference_moe_impl,
|
||||
run_modular_kernel,
|
||||
)
|
||||
from .mk_objects import (
|
||||
MK_FUSED_EXPERT_TYPES,
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||
MK_QUANT_CONFIGS,
|
||||
)
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
|
||||
|
||||
|
||||
class Result(Enum):
    """Outcome of a single feature-matrix test case."""

    PASS = 1
    FAIL = 2
    SKIP = 3
|
||||
|
||||
|
||||
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    config: Config,
    weights: WeightTensors,
):
    """Per-rank worker: compare modular kernel vs reference for every (m, topk).

    Raises on mismatch via ``torch.testing.assert_close``.
    """
    current_platform.seed_everything(pgi.rank)

    # sanity check
    from vllm import envs

    if config.fused_moe_chunk_size is not None:
        assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # get weights to this device
    weights.to_current_device()

    Ms = config.Ms
    assert isinstance(Ms, list)
    TOPKs = config.topks
    assert isinstance(TOPKs, list)

    for m, topk in product(Ms, TOPKs):
        print(f"Running m={m}, topk={topk} ...")
        # override m and topk
        cfgx = copy.deepcopy(config)
        cfgx.Ms = m
        cfgx.topks = topk

        # inputs for rank
        rank_tensors = RankTensors.make(cfgx, pgi)

        # modular kernel out
        mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors)

        with set_current_vllm_config(vllm_config):
            ref_out = reference_moe_impl(cfgx, weights, rank_tensors)

        torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
|
||||
|
||||
|
||||
def make_feature_matrix(csv_file_path: str):
    """Sweep every (shape, prepare-finalize, experts, quant) combination and
    record a PASS/FAIL/SKIP row per combination into a .csv feature matrix.

    Invalid combinations are skipped; failures are recorded rather than
    raised so one bad combination does not abort the whole sweep.
    """
    from dataclasses import asdict

    import pandas as pd

    def add_to_results(
        config: Config, success: Result, results_df: pd.DataFrame | None = None
    ):
        """Append one row describing `config` and its outcome to results_df."""
        config_dict = asdict(config)
        # Store class names, not class objects, so the csv stays readable.
        config_dict["prepare_finalize_type"] = config_dict[
            "prepare_finalize_type"
        ].__name__
        config_dict["fused_experts_type"] = config_dict["fused_experts_type"].__name__
        config_dict["per_tensor_act_quant"] = config.is_per_tensor_act_quant

        # Flatten the quant config into top-level columns; an absent quant
        # config is reported as the unquantized default.
        quant_config_dict = config_dict.pop("quant_config")
        if quant_config_dict is None:
            quant_config_dict = asdict(FUSED_MOE_UNQUANTIZED_CONFIG)

        result_dict = config_dict | quant_config_dict | {"success": success.name}

        result_df = pd.DataFrame([result_dict])
        if results_df is None:
            results_df = result_df
        else:
            results_df = pd.concat([results_df, result_df], ignore_index=True)

        return results_df

    Ms = [64]
    Ks = [7168]  # hidden sizes
    Ns = [2048]
    TOPKs = [[4, 1]]
    Es = [32]
    DTYPEs = [torch.bfloat16]
    PF_TYPES = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    FE_TYPES = MK_FUSED_EXPERT_TYPES
    Q_TYPES = MK_QUANT_CONFIGS

    combinations = list(
        product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)
    )

    results_df: pd.DataFrame | None = None
    for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
        combinations
    ):
        config = Config(
            Ms=[m],
            K=k,
            N=n,
            E=e,
            topks=topks,
            dtype=dtype,
            prepare_finalize_type=pf_type,
            fused_experts_type=experts_type,
            quant_config=quant_config,
            world_size=2,
            fused_moe_chunk_size=None,
        )

        if not config.is_valid()[0]:
            success = Result.SKIP
        else:
            print(f"Running config : {config.describe()} ...")
            try:
                weights: WeightTensors = WeightTensors.make(config)
                vllm_config, env_dict = config.make_env_data()
                parallel_launch_with_config(
                    config.world_size,
                    rank_worker,
                    vllm_config,
                    env_dict,
                    config,
                    weights,
                )
                success = Result.PASS
            except Exception as ex:
                # Record the failure instead of aborting the sweep, but
                # surface the reason (previously the exception was
                # silently discarded).
                print(f"Config failed : {ex}")
                success = Result.FAIL

        results_df = add_to_results(config, success, results_df)

    if results_df is not None:
        # csv_file_path is already a string; no f-string wrapping needed.
        results_df.to_csv(csv_file_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser(
        description=(
            "Make ModularKernel feature matrix \n"
            "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix "  # noqa: E501
            "-f ./feature_matrices/feature_matrix.csv"
        )
    )

    parser.add_argument(
        "-f",
        "--feature-matrix-csv-file-path",
        type=str,
        required=True,
        help="File name to Generate a .csv file",
    )
    args = parser.parse_args()

    csv_path = args.feature_matrix_csv_file_path
    # Require a real ".csv" suffix to match the error message; the old check
    # used endswith("csv") and therefore accepted paths like "foocsv".
    assert csv_path.endswith(".csv"), (
        f"Need a file path ending with .csv, got {csv_path}"
    )
    assert Path(csv_path).parent.is_dir(), (
        f"Cannot find parent directory for {Path(csv_path).parent}"
    )

    make_feature_matrix(args.feature_matrix_csv_file_path)
|
||||
509
tests/kernels/moe/modular_kernel_tools/mk_objects.py
Normal file
509
tests/kernels/moe/modular_kernel_tools/mk_objects.py
Normal file
@@ -0,0 +1,509 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
# Fused experts and PrepareFinalize imports
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts,
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_fp8_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
|
||||
|
||||
@dataclass
class TestMoEQuantConfig:
    """Quantization settings for one test case (mirrors FusedMoEQuantConfig)."""

    # Quantized dtype (e.g. torch.float8_e4m3fn) or the string "nvfp4";
    # None means unquantized.
    quant_dtype: torch.dtype | str | None
    # Per-output-channel weight quantization.
    per_out_ch_quant: bool
    # Per-token (vs per-tensor) activation quantization.
    per_act_token_quant: bool
    # [block_n, block_k] for block-wise quantization, or None.
    block_shape: list[int] | None
|
||||
|
||||
|
||||
@dataclass
class PrepareFinalizeInfo:
    """Capabilities of one FusedMoEPrepareAndFinalize implementation."""

    # Activation layout produced/consumed (standard or batched experts).
    activation_format: mk.FusedMoEActivationFormat
    # Dtypes (or "nvfp4") the implementation accepts.
    supported_dtypes: list[torch.dtype | str]
    # Whether block-wise quantization is handled.
    blocked_quantization_support: bool
    # all2all backend name; None when no backend is required.
    backend: str | None
    # Whether router weights may be applied on the inputs.
    supports_apply_weight_on_input: bool = True
|
||||
|
||||
|
||||
@dataclass
class ExpertInfo:
    """Capabilities of one FusedMoEPermuteExpertsUnpermute implementation."""

    # Activation layout produced/consumed (standard or batched experts).
    activation_format: mk.FusedMoEActivationFormat
    # Dtypes (or "nvfp4") the implementation accepts.
    supported_dtypes: list[torch.dtype | str]
    # Whether block-wise quantization is handled.
    blocked_quantization_support: bool
    # Whether the kernel can process tokens in chunks.
    supports_chunking: bool
    # Whether an expert map (global -> local expert ids) is supported.
    supports_expert_map: bool
    # Whether the quant scheme must match the prepare/finalize side.
    needs_matching_quant: bool = False
    # Whether the implementation requires deep_gemm to be available.
    needs_deep_gemm: bool = False
|
||||
|
||||
|
||||
# Capability registries, keyed by class; populated by the register_*
# helpers below.
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
# All registered prepare/finalize classes, plus the same set split by
# multi-GPU vs single-GPU use.
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
# All registered fused-experts classes.
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []

# Shorthands used by the registrations below.
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
common_float_types: list[torch.dtype | str] = [
    torch.float8_e4m3fn,
    torch.bfloat16,
    torch.float16,
    torch.float32,
]
common_float_and_int_types = common_float_types + [torch.int8]
nvfp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]
|
||||
|
||||
|
||||
def register_prepare_and_finalize(
    kind,
    activation_format: mk.FusedMoEActivationFormat,
    supported_dtypes: list[torch.dtype | str],
    blocked_quantization_support: bool,
    backend: str | None,
    force_multigpu: bool = False,
    supports_apply_weight_on_input: bool = True,
):
    """Record a FusedMoEPrepareAndFinalize class in the module registries.

    Args:
        kind: the prepare/finalize class being registered.
        activation_format: standard or batched-experts activation layout.
        supported_dtypes: dtypes (or "nvfp4") the implementation accepts.
        blocked_quantization_support: whether block-wise quant is handled.
        backend: all2all backend name; non-None implies multi-GPU.
        force_multigpu: treat as multi-GPU even without a backend name.
        supports_apply_weight_on_input: whether router weights may be
            applied on the inputs.
    """
    # NOTE: the module-level registries are only mutated (never rebound),
    # so the previous `global` declarations were unnecessary and removed.
    assert kind not in PREPARE_FINALIZE_INFO, f"{kind} registered twice"

    PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo(
        activation_format,
        supported_dtypes,
        blocked_quantization_support,
        backend,
        supports_apply_weight_on_input,
    )
    MK_ALL_PREPARE_FINALIZE_TYPES.append(kind)
    # A backend (or an explicit override) marks the class as multi-GPU.
    if backend is not None or force_multigpu:
        MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind)
    else:
        MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind)
|
||||
|
||||
|
||||
def register_experts(
    kind,
    activation_format: mk.FusedMoEActivationFormat,
    supported_dtypes: list[torch.dtype | str],
    blocked_quantization_support: bool,
    supports_chunking: bool,
    supports_expert_map: bool,
    needs_matching_quant: bool = False,
    needs_deep_gemm: bool = False,
):
    """Record a FusedMoEPermuteExpertsUnpermute class in the registries.

    Args:
        kind: the fused-experts class being registered.
        activation_format: standard or batched-experts activation layout.
        supported_dtypes: dtypes (or "nvfp4") the implementation accepts.
        blocked_quantization_support: whether block-wise quant is handled.
        supports_chunking: whether tokens may be processed in chunks.
        supports_expert_map: whether an expert map is supported.
        needs_matching_quant: quant scheme must match prepare/finalize.
        needs_deep_gemm: implementation requires deep_gemm.
    """
    # NOTE: the module-level registries are only mutated (never rebound),
    # so the previous `global` declarations were unnecessary and removed.
    assert kind not in EXPERT_INFO, f"{kind} registered twice"

    EXPERT_INFO[kind] = ExpertInfo(
        activation_format,
        supported_dtypes,
        blocked_quantization_support,
        supports_chunking,
        supports_expert_map,
        needs_matching_quant,
        needs_deep_gemm,
    )
    MK_FUSED_EXPERT_TYPES.append(kind)
|
||||
|
||||
|
||||
def prepare_finalize_info(kind) -> PrepareFinalizeInfo:
    """Return the PrepareFinalizeInfo registered for `kind`.

    Raises:
        AssertionError: if `kind` was never registered.
    """
    info = PREPARE_FINALIZE_INFO.get(kind)
    # Message the failure so the unregistered type is identifiable.
    assert info is not None, f"{kind} is not a registered PrepareAndFinalize type"
    return info
|
||||
|
||||
|
||||
def expert_info(kind) -> ExpertInfo:
    """Return the ExpertInfo registered for `kind`.

    Raises:
        AssertionError: if `kind` was never registered.
    """
    info = EXPERT_INFO.get(kind)
    # Message the failure so the unregistered type is identifiable.
    assert info is not None, f"{kind} is not a registered fused-experts type"
    return info
|
||||
|
||||
|
||||
# Single-GPU baseline: no expert-parallel communication.
register_prepare_and_finalize(
    MoEPrepareAndFinalizeNoEP,
    standard_format,
    common_float_types,
    blocked_quantization_support=True,
    backend=None,
)

# Triton experts over the batched activation layout.
register_experts(
    BatchedTritonExperts,
    batched_format,
    common_float_types,
    blocked_quantization_support=True,
    supports_chunking=False,
    supports_expert_map=False,
    needs_matching_quant=True,
)

# Triton experts over the standard activation layout.
register_experts(
    TritonExperts,
    standard_format,
    common_float_and_int_types,
    blocked_quantization_support=True,
    supports_chunking=True,
    supports_expert_map=True,
    needs_matching_quant=True,
)

# Reference (non-fused) batched implementation.
register_experts(
    NaiveBatchedExperts,
    batched_format,
    common_float_and_int_types,
    blocked_quantization_support=True,
    supports_chunking=False,
    supports_expert_map=True,
)
|
||||
|
||||
# DeepEP registrations (both high-throughput and low-latency backends).
# Disable on blackwell for now
if has_deep_ep() and not current_platform.has_device_capability(100):
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
        DeepEPHTPrepareAndFinalize,
    )
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
        DeepEPLLPrepareAndFinalize,
    )

    # High-throughput variant uses the standard activation layout.
    register_prepare_and_finalize(
        DeepEPHTPrepareAndFinalize,
        standard_format,
        common_float_types,
        blocked_quantization_support=True,
        backend="deepep_high_throughput",
    )

    # Low-latency variant uses the batched activation layout.
    register_prepare_and_finalize(
        DeepEPLLPrepareAndFinalize,
        batched_format,
        common_float_types,
        blocked_quantization_support=True,
        backend="deepep_low_latency",
    )
|
||||
|
||||
# pplx all2all backend registration (only when the package is installed).
if has_pplx():
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize,
    )

    register_prepare_and_finalize(
        PplxPrepareAndFinalize,
        batched_format,
        common_float_and_int_types,
        blocked_quantization_support=True,
        backend="pplx",
    )
|
||||
|
||||
# FlashInfer cutlass fused-MoE registrations (requires device capability 100).
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
        FlashInferExperts,
    )
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
        FlashInferCutlassMoEPrepareAndFinalize,
        create_flashinfer_prepare_finalize,
    )

    # No all2all backend name, but still a multi-GPU prepare/finalize.
    register_prepare_and_finalize(
        FlashInferCutlassMoEPrepareAndFinalize,
        standard_format,
        nvfp4_types + fp8_types,
        blocked_quantization_support=True,
        backend=None,
        force_multigpu=True,
        supports_apply_weight_on_input=False,
    )

    register_experts(
        FlashInferExperts,
        standard_format,
        nvfp4_types + fp8_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        # Note: this is a hack to get it to run for now
        supports_expert_map=True,
    )
else:
    # Sentinel so make_prepare_finalize() can compare against this name even
    # when flashinfer is unavailable.
    FlashInferCutlassMoEPrepareAndFinalize = None
|
||||
|
||||
# DeepGEMM-backed experts (fp8 only for the pure-deep-gemm variants).
if has_deep_gemm() and is_deep_gemm_supported():
    register_experts(
        BatchedDeepGemmExperts,
        batched_format,
        fp8_types,
        blocked_quantization_support=True,
        supports_chunking=False,
        supports_expert_map=False,
        needs_matching_quant=False,
        needs_deep_gemm=True,
    )
    register_experts(
        DeepGemmExperts,
        standard_format,
        fp8_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=True,
        needs_matching_quant=False,
        needs_deep_gemm=True,
    )
    # Hybrid implementation falls back to triton for non-fp8 dtypes.
    register_experts(
        TritonOrDeepGemmExperts,
        standard_format,
        common_float_and_int_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=True,
        needs_matching_quant=True,
        needs_deep_gemm=True,
    )
|
||||
|
||||
# Cutlass fp8 experts (standard and batched layouts).
if cutlass_fp8_supported():
    from vllm.model_executor.layers.fused_moe import (
        CutlassBatchedExpertsFp8,
        CutlassExpertsFp8,
    )

    register_experts(
        CutlassExpertsFp8,
        standard_format,
        fp8_types,
        blocked_quantization_support=False,
        supports_chunking=True,
        supports_expert_map=False,
    )
    register_experts(
        CutlassBatchedExpertsFp8,
        batched_format,
        fp8_types,
        blocked_quantization_support=False,
        supports_chunking=False,
        supports_expert_map=False,
    )
|
||||
|
||||
# Cutlass nvfp4 experts.
if cutlass_fp4_supported():
    from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4

    register_experts(
        CutlassExpertsFp4,
        standard_format,
        nvfp4_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=False,
    )
|
||||
|
||||
# Quantization configurations swept by the tests; None means unquantized.
MK_QUANT_CONFIGS: list[TestMoEQuantConfig | None] = [
    None,
    # per-channel / per-column weights and per-tensor activations
    TestMoEQuantConfig(
        quant_dtype=torch.float8_e4m3fn,
        per_out_ch_quant=True,
        per_act_token_quant=False,
        block_shape=None,
    ),
    # per-channel / per-column weights and per-token activations
    TestMoEQuantConfig(
        quant_dtype=torch.float8_e4m3fn,
        per_out_ch_quant=True,
        per_act_token_quant=True,
        block_shape=None,
    ),
    # per-tensor weights and per-tensor activations
    TestMoEQuantConfig(
        quant_dtype=torch.float8_e4m3fn,
        per_out_ch_quant=False,
        per_act_token_quant=False,
        block_shape=None,
    ),
    # per-tensor weights and per-token activations
    TestMoEQuantConfig(
        quant_dtype=torch.float8_e4m3fn,
        per_out_ch_quant=False,
        per_act_token_quant=True,
        block_shape=None,
    ),
    # block-quantized weights and 128 block per-token activations
    TestMoEQuantConfig(
        quant_dtype=torch.float8_e4m3fn,
        per_out_ch_quant=False,
        per_act_token_quant=False,
        block_shape=[128, 128],
    ),
    # TODO (varun) : Should we test the following combinations ?
    # block-quantized weights and per-token activations
    # block-quantized weights and per-tensor activations
]

# nvfp4 is only exercised when a cutlass fp4 kernel or the flashinfer
# cutlass fused-moe is available on this machine.
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
    MK_QUANT_CONFIGS += [
        TestMoEQuantConfig(
            quant_dtype="nvfp4",
            per_out_ch_quant=False,
            per_act_token_quant=False,
            block_shape=None,
        ),
    ]
|
||||
|
||||
|
||||
def make_prepare_finalize(
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    backend: str | None,
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
    """Construct the prepare/finalize object for the given backend/type."""
    # A real all2all backend (anything other than None or "naive") builds
    # its own prepare-finalize through the vLLM helper.
    if backend not in (None, "naive"):
        pf = maybe_make_prepare_finalize(moe, quant_config)
        assert pf is not None
        return pf
    if prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
        return create_flashinfer_prepare_finalize(
            use_dp=moe.moe_parallel_config.dp_size > 1
        )
    # Default: single-GPU, no expert-parallel communication.
    return MoEPrepareAndFinalizeNoEP()
|
||||
|
||||
|
||||
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
|
||||
s = rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
return t[s:e]
|
||||
|
||||
|
||||
def make_cutlass_strides(
    e: int,
    n: int,
    k: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Per-expert stride tensors for the cutlass fp8 grouped-gemm kernels."""

    def filled(value: int) -> torch.Tensor:
        # One int64 entry per expert, all set to the same stride.
        return torch.full((e,), value, device="cuda", dtype=torch.int64)

    # gemm1: A/B stride k, C stride 2*n; gemm2: A/B stride n, C stride k.
    return filled(k), filled(n), filled(2 * n), filled(k)
|
||||
|
||||
|
||||
def make_fused_experts(
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    num_dispatchers: int,
    N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Instantiate the requested fused-experts implementation for `moe`.

    Args:
        fused_experts_type: one of the classes in MK_FUSED_EXPERT_TYPES.
        moe: overall MoE configuration (sizes, dtypes, parallel layout).
        quant_config: quantization settings passed to every implementation.
        num_dispatchers: number of ranks dispatching to batched experts.
        N: intermediate (up-projection) size, used for cutlass strides.

    Raises:
        RuntimeError: if `fused_experts_type` is not a known class.
    """
    batch_kwargs = {
        "max_num_tokens": moe.max_num_tokens,
        "num_dispatchers": num_dispatchers,
    }
    quant_kwargs = {
        "quant_config": quant_config,
    }
    deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}

    # Suppress tensor contents while logging kwargs so each print stays on
    # a single line.
    torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)

    # Construct inside try/finally so the print options are restored even
    # when a constructor raises (previously they stayed clobbered).
    try:
        if fused_experts_type == BatchedDeepGemmExperts:
            kwargs = batch_kwargs | quant_kwargs
            print(f"Making BatchedDeepGemmExperts {kwargs} ...")
            experts = BatchedDeepGemmExperts(**kwargs)
        elif fused_experts_type == BatchedTritonExperts:
            kwargs = batch_kwargs | quant_kwargs
            print(f"Making BatchedTritonExperts {kwargs} ...")
            experts = BatchedTritonExperts(**kwargs)
        elif fused_experts_type == DeepGemmExperts:
            print(f"Making DeepGemmExperts {quant_config} ...")
            experts = DeepGemmExperts(quant_config)
        elif fused_experts_type == TritonExperts:
            kwargs = quant_kwargs
            print(f"Making TritonExperts {kwargs} ...")
            experts = TritonExperts(**kwargs)
        elif fused_experts_type == TritonOrDeepGemmExperts:
            kwargs = quant_kwargs | deepgemm_kwargs
            print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
            experts = TritonOrDeepGemmExperts(**kwargs)
        elif fused_experts_type == NaiveBatchedExperts:
            kwargs = batch_kwargs | quant_kwargs
            print(f"Making NaiveBatchedExperts {kwargs} ...")
            experts = NaiveBatchedExperts(**kwargs)
        elif fused_experts_type == CutlassExpertsFp8:
            strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
            kwargs = {
                "out_dtype": moe.in_dtype,
                "ab_strides1": strides[0],
                "ab_strides2": strides[1],
                "c_strides1": strides[2],
                "c_strides2": strides[3],
            } | quant_kwargs
            print(f"Making CutlassExpertsFp8 {kwargs} ...")
            experts = CutlassExpertsFp8(**kwargs)
        elif fused_experts_type == CutlassBatchedExpertsFp8:
            strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
            kwargs = {
                "max_experts_per_worker": moe.num_local_experts,
                "num_dispatchers": num_dispatchers,
                "out_dtype": moe.in_dtype,
                "ab_strides1": strides[0],
                "ab_strides2": strides[1],
                "c_strides1": strides[2],
                "c_strides2": strides[3],
            } | quant_kwargs
            print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
            experts = CutlassBatchedExpertsFp8(**kwargs)
        elif fused_experts_type == CutlassExpertsFp4:
            kwargs = {
                "max_experts_per_worker": moe.num_local_experts,
                "num_dispatchers": num_dispatchers,
                "out_dtype": moe.in_dtype,
            } | quant_kwargs
            print(f"Making CutlassExpertsFp4 {kwargs} ...")
            experts = CutlassExpertsFp4(**kwargs)
        elif fused_experts_type == FlashInferExperts:
            kwargs = {
                "out_dtype": moe.in_dtype,
                "ep_rank": moe.ep_rank,
                "ep_size": moe.ep_size,
                "tp_rank": moe.tp_rank,
                "tp_size": moe.tp_size,
            } | quant_kwargs
            print(f"Making FlashInferExperts {kwargs} ...")
            experts = FlashInferExperts(**kwargs)
        else:
            raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
    finally:
        # NOTE(review): this restores torch's documented defaults rather
        # than the caller's previous settings; torch exposes no public
        # getter for print options.
        torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)

    return experts
|
||||
134
tests/kernels/moe/modular_kernel_tools/parallel_utils.py
Normal file
134
tests/kernels/moe/modular_kernel_tools/parallel_utils.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Concatenate
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import init_distributed_environment, initialize_model_parallel
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ProcessGroupInfo:
    """Identity of one spawned worker within the test process group."""

    # Total number of ranks across all nodes.
    world_size: int
    # Number of ranks on this node.
    world_local_size: int
    # Global rank of this process.
    rank: int
    # Index of this node.
    node_rank: int
    # Rank of this process within its node (also the CUDA device index).
    local_rank: int
    # CUDA device assigned to this process.
    device: torch.device
|
||||
|
||||
|
||||
def _set_vllm_config(
    vllm_config: VllmConfig, world_size: int, rank: int, local_rank: int
):
    """Initialize vLLM's distributed environment and model-parallel groups
    under `vllm_config`, and return a gloo CPU process group."""
    import tempfile

    # mkstemp returns (fd, path); close the fd immediately so it is not
    # leaked — only the path is needed for the file:// rendezvous. The old
    # code (`tempfile.mkstemp()[1]`) left the descriptor open.
    fd, temp_file = tempfile.mkstemp()
    os.close(fd)

    with set_current_vllm_config(vllm_config):
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=f"file://{temp_file}",
            local_rank=local_rank,
            backend="nccl",
        )

        initialize_model_parallel(
            tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
            pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
        )
        cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
    return cpu_group
|
||||
|
||||
|
||||
def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig | None, Any, P], None],
    vllm_config: VllmConfig | None,
    env_dict: dict | None,
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Entry point of each spawned process: set up torch.distributed (and,
    when vllm_config is given, vLLM's distributed state), then run `worker`.

    Always destroys the process group on exit, even when `worker` raises.
    """
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    # Cheap collective that doubles as a startup barrier / sanity check.
    barrier = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(barrier)

    # Apply per-test environment overrides in the child process.
    if env_dict is not None:
        os.environ.update(env_dict)

    cpu_group = None
    if vllm_config is not None:
        cpu_group = _set_vllm_config(vllm_config, world_size, rank, local_rank)

    try:
        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            vllm_config,
            cpu_group,
            *args,
            **kwargs,
        )
    except Exception as ex:
        # Print in the child before re-raising so the failure is visible
        # even if the parent only sees a nonzero exit code.
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch_with_config(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None],
    vllm_config: VllmConfig,
    env_dict: dict[Any, Any],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Spawn `world_size` single-node workers running `worker` with the given
    vllm config and per-process env overrides, and join them."""
    assert not kwargs
    # Single node: world_local_size == world_size and node_rank == 0.
    launch_args = (
        world_size,
        world_size,
        0,
        f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
        worker,
        vllm_config,
        env_dict,
    ) + args
    spawn(
        _worker_parallel_launch,
        args=launch_args,
        nprocs=world_size,
        join=True,
    )
|
||||
137
tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py
Normal file
137
tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from collections.abc import Callable
|
||||
from itertools import product
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .common import Config, RankTensors, WeightTensors, make_modular_kernel
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
|
||||
|
||||
|
||||
def do_profile(
    fn: Callable,
    fn_kwargs: dict[Any, Any],
    pgi: ProcessGroupInfo,
    config: Config,
    num_warmups: int = 5,
):
    """Warm up `fn`, then capture one profiled invocation and export a
    per-rank chrome trace under config.torch_trace_dir_path."""
    for _ in range(num_warmups):
        fn(**fn_kwargs)

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        with_stack=True,
        record_shapes=True,
    ) as tprof:
        fn(**fn_kwargs)
        # Make sure all CUDA work lands inside the profiled region.
        torch.cuda.synchronize(torch.cuda.current_device())

    # TODO (varun): Add a descriptive trace file name
    # NOTE(review): callers in this module set config.Ms (not config.M) —
    # confirm Config defines an `M` accessor, otherwise this raises.
    tprof.export_chrome_trace(
        f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json"
    )
|
||||
|
||||
|
||||
def profile_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> None:
    """Assemble this rank's modular kernel and inputs, then profile one
    forward pass via do_profile()."""
    # Callers must have specialized the config to a single (M, topk) case.
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # weights for rank
    rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    # make modular kernel
    mk = make_modular_kernel(config, vllm_config, weights)

    mk_kwargs = {
        "hidden_states": rank_tensors.hidden_states,
        "w1": rank_weights.w1,
        "w2": rank_weights.w2,
        "topk_weights": rank_tensors.topk_weights,
        "topk_ids": rank_tensors.topk_ids,
        "expert_map": rank_tensors.expert_map,
        "w1_scale": rank_weights.w1_scale,
        "w2_scale": rank_weights.w2_scale,
        "a1_scale": rank_tensors.hidden_states_scale,
        "global_num_experts": config.E,
        # NOTE(review): the rest of this module uses config.topks (asserted
        # to be an int above) — confirm Config also defines a `topk`
        # accessor; this may be a typo for config.topks.
        "apply_router_weight_on_input": config.topk == 1,
    }

    do_profile(mk.forward, mk_kwargs, pgi, config)
|
||||
|
||||
|
||||
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    config: Config,
    weights: WeightTensors,
):
    """Per-rank body: profile the modular kernel for every (M, topk) case."""
    current_platform.seed_everything(pgi.rank)

    # sanity check: the env var must reflect the requested chunk size
    from vllm import envs

    if config.fused_moe_chunk_size is not None:
        assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # move the shared weights onto this rank's device
    weights.to_current_device()

    m_values = config.Ms
    topk_values = config.topks
    assert isinstance(m_values, list)
    assert isinstance(topk_values, list)

    for m, topk in product(m_values, topk_values):
        print(f"Running m={m}, topk={topk} ...")
        # specialize a copy of the config to this single (m, topk) case
        case_config = copy.deepcopy(config)
        case_config.Ms = m
        case_config.topks = topk

        # inputs for this rank
        rank_tensors = RankTensors.make(case_config, pgi)
        profile_modular_kernel(pgi, vllm_config, case_config, weights, rank_tensors)
|
||||
|
||||
|
||||
def run(config: Config):
    """Build weights and env data for `config`, then launch the profiling
    workers across config.world_size ranks."""
    weights: WeightTensors = WeightTensors.make(config)
    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(
        config.world_size,
        rank_worker,
        vllm_config,
        env_dict,
        config,
        weights,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from .cli_args import make_config, make_config_arg_parser

    parser = make_config_arg_parser(
        description=(
            "Run single prepare-finalize & fused-experts combination test"
            "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel "  # noqa: E501
            "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
        )
    )
    args = parser.parse_args()
    # A trace output directory is mandatory for profiling runs.
    assert args.torch_trace_dir_path is not None, (
        "Please pass in a directory to store torch traces"
    )
    config = make_config(args)

    run(config)
|
||||
202
tests/kernels/moe/parallel_utils.py
Normal file
202
tests/kernels/moe/parallel_utils.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
DeepEP test utilities
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from typing import Concatenate
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.utils.import_utils import has_deep_ep
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ProcessGroupInfo:
    """Identity of one spawned worker within the test process group."""

    # Total number of ranks across all nodes.
    world_size: int
    # Number of ranks on this node.
    world_local_size: int
    # Global rank of this process.
    rank: int
    # Index of this node.
    node_rank: int
    # Rank of this process within its node (also the CUDA device index).
    local_rank: int
    # CUDA device assigned to this process.
    device: torch.device
|
||||
|
||||
|
||||
def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Entry point of each spawned process: initialize torch.distributed,
    then invoke `worker` with this process's ProcessGroupInfo.

    Always destroys the process group on exit, even when `worker` raises.
    """
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    # Cheap collective that doubles as a startup barrier / sanity check.
    barrier = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(barrier)

    try:
        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            *args,
            **kwargs,
        )
    except Exception as ex:
        # Print in the child before re-raising so the failure is visible
        # even if the parent only sees a nonzero exit code.
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Spawn ``world_size`` single-node processes and run ``worker`` on each.

    Keyword arguments are not supported because ``torch.multiprocessing.spawn``
    only forwards positional arguments to the child processes.
    """
    assert not kwargs
    # Single node: node_rank is 0 and the node-local size equals world_size.
    init_url = f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}"
    spawn(
        _worker_parallel_launch,
        args=(world_size, world_size, 0, init_url, worker, *args),
        nprocs=world_size,
        join=True,
    )
|
||||
|
||||
|
||||
## DeepEP specific utils
|
||||
|
||||
|
||||
@dataclasses.dataclass
class DeepEPHTArgs:
    """Arguments for the DeepEP high-throughput all2all setup."""

    # number of experts resident on this rank
    num_local_experts: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
class DeepEPLLArgs:
    """Arguments for the DeepEP low-latency all2all setup."""

    # upper bound on tokens dispatched per rank
    max_tokens_per_rank: int
    # model hidden size
    hidden_size: int
    # total (global) number of experts
    num_experts: int
    # whether to quantize to fp8 during dispatch
    use_fp8_dispatch: bool
|
||||
|
||||
|
||||
def make_deepep_ht_a2a(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    ht_args: DeepEPHTArgs,
    # NOTE(review): q_dtype and block_shape are accepted for signature parity
    # with make_deepep_ll_a2a but are not used in this function's body.
    q_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
):
    """Build a DeepEP high-throughput all2all Buffer and wrap it in a
    DeepEPHTPrepareAndFinalize for use with the modular MoE kernel."""
    import deep_ep

    # high throughput a2a
    num_nvl_bytes = 1024 * 1024 * 1024  # 1GB
    # HT mode uses NVLink only: no RDMA buffer, low-latency mode off.
    num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
    buffer = deep_ep.Buffer(
        group=pg,
        num_nvl_bytes=num_nvl_bytes,
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=low_latency_mode,
        num_qps_per_rank=num_qps_per_rank,
    )
    return DeepEPHTPrepareAndFinalize(
        buffer=buffer,
        num_dispatchers=pgi.world_size,
        dp_size=dp_size,
        # Experts are sharded contiguously, so this rank's first expert id is
        # rank * num_local_experts.
        rank_expert_offset=pgi.rank * ht_args.num_local_experts,
    )
|
||||
|
||||
|
||||
def make_deepep_ll_a2a(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    deepep_ll_args: DeepEPLLArgs,
    # NOTE(review): q_dtype and block_shape are accepted for signature parity
    # with make_deepep_ht_a2a but are not used in this function's body.
    q_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
):
    """Build a DeepEP low-latency all2all Buffer and wrap it in a
    DeepEPLLPrepareAndFinalize for use with the modular MoE kernel."""
    import deep_ep

    # low-latency a2a
    # Let deep_ep size the RDMA buffer for the worst-case dispatch.
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        deepep_ll_args.max_tokens_per_rank,
        deepep_ll_args.hidden_size,
        pgi.world_size,
        deepep_ll_args.num_experts,
    )

    buffer = deep_ep.Buffer(
        group=pg,
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        # One queue pair per local expert (experts sharded evenly over ranks).
        num_qps_per_rank=deepep_ll_args.num_experts // pgi.world_size,
    )

    return DeepEPLLPrepareAndFinalize(
        buffer=buffer,
        num_dispatchers=pgi.world_size,
        max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
        use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
    )
|
||||
|
||||
|
||||
def make_deepep_a2a(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    deepep_ht_args: DeepEPHTArgs | None,
    deepep_ll_args: DeepEPLLArgs | None,
    q_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
):
    """Build either the high-throughput or low-latency DeepEP all2all,
    depending on which (exactly one) of the two arg bundles is provided."""
    if deepep_ht_args is None:
        # Low-latency path: the LL args must be present instead.
        assert deepep_ll_args is not None
        return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape)

    # High-throughput path: the two modes are mutually exclusive.
    assert deepep_ll_args is None
    return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype, block_shape)
|
||||
106
tests/kernels/moe/test_batched_deepgemm.py
Normal file
106
tests/kernels/moe/test_batched_deepgemm.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
|
||||
|
||||
from .test_deepgemm import make_block_quant_fp8_weights
|
||||
|
||||
BLOCK_SIZE = [128, 128]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
@pytest.mark.parametrize("E", [16, 32])  # number of experts
@pytest.mark.parametrize("T", [256, 512])  # tokens per expert
@pytest.mark.parametrize("K", [128, 256])  # hidden dim
@pytest.mark.parametrize("N", [512, 1024])  # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(
    E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init
):
    """Compare BatchedDeepGemmExperts to BatchedTritonExperts.

    Runs the same fp8 block-quantized batched MoE problem through both
    expert implementations (sharing one prepare/finalize stage and one
    quant config) and asserts the outputs agree within tolerance.
    """

    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")

    device = "cuda"
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(E, N, K, BLOCK_SIZE)

    M = E * T  # total tokens
    a = torch.randn(M, K, device=device, dtype=torch.bfloat16) / 10.0
    # Keep activations within the representable fp8 range before quantization.
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    a.clamp_(fp8_info.min, fp8_info.max)

    # random router outputs → top-k indices / weights
    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    # token number for each expert
    cnt = torch.bincount(topk_ids.flatten(), minlength=E)
    max_cnt = int(cnt.max().item())
    # next power of 2 for max token number
    max_num_tokens = 1 << (max_cnt - 1).bit_length()

    # Single-rank prepare/finalize shared by both expert implementations.
    prep_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=max_num_tokens,
        num_local_experts=E,
        num_dispatchers=1,
        rank=0,
    )

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_s,
        w2_scale=w2_s,
        per_act_token_quant=False,
        block_shape=BLOCK_SIZE,
    )

    # triton (reference)
    triton_experts = BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        quant_config=quant_config,
    )
    mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)

    out_triton = mk_triton(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        global_num_experts=E,
    )

    # deepgemm
    deepgemm_experts = BatchedDeepGemmExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        quant_config=quant_config,
    )
    mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)

    out_deepgemm = mk_deepgemm(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        global_num_experts=E,
    )

    diff = calc_diff(out_deepgemm, out_triton)
    assert diff < 1e-3, f"Output diff too large: {diff}"
|
||||
352
tests/kernels/moe/test_batched_moe.py
Normal file
352
tests/kernels/moe/test_batched_moe.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import (
|
||||
batched_moe,
|
||||
make_quantized_test_activations,
|
||||
make_test_weights,
|
||||
naive_batched_moe,
|
||||
)
|
||||
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
invoke_moe_batched_triton_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 512, 512),
|
||||
(1, 1024, 2048),
|
||||
(32, 128, 128),
|
||||
(32, 512, 512),
|
||||
(32, 1024, 2048),
|
||||
(45, 128, 2048),
|
||||
(45, 1024, 128),
|
||||
(64, 512, 512),
|
||||
(64, 1024, 2048),
|
||||
(222, 128, 2048),
|
||||
(222, 1024, 2048),
|
||||
]
|
||||
NUM_EXPERTS = [8, 64]
|
||||
TOP_KS = [1, 2, 6]
|
||||
|
||||
DTYPES = [torch.bfloat16]
|
||||
|
||||
if not current_platform.is_fp8_fnuz():
|
||||
DTYPES.append(torch.float8_e4m3fn)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
|
||||
@dataclass
class BatchedMMConfig:
    """Problem-size and dtype configuration for a batched per-expert matmul."""

    # dtype of the unquantized input activations
    in_dtype: torch.dtype
    # quantized dtype, or None when running unquantized
    quant_dtype: torch.dtype | None
    # dtype of the output tensor
    out_dtype: torch.dtype
    # number of experts (leading batch dimension)
    num_experts: int
    # per-expert token capacity
    max_tokens_per_expert: int
    # input (reduction) dimension
    K: int
    # output dimension
    N: int
|
||||
|
||||
|
||||
@dataclass
class BatchedMMTensors:
    """Tensor bundle for a batched matmul test: C[e] = A[e] @ B[e]^T per expert."""

    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, K, N] - column major
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        """Allocate random CUDA inputs and a zeroed output per `config`."""
        # Scale down activations to keep values in a numerically tame range.
        A = (
            torch.randn(
                (config.num_experts, config.max_tokens_per_expert, config.K),
                device="cuda",
                dtype=config.in_dtype,
            )
            / 10
        )
        # Weights stored as [E, N, K] (i.e. transposed / column-major layout).
        B = torch.randn(
            (config.num_experts, config.N, config.K),
            device="cuda",
            dtype=config.in_dtype,
        )
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.out_dtype,
        )

        # Random valid-token count per expert, strictly below capacity.
        num_expert_tokens = torch.randint(
            low=0,
            high=config.max_tokens_per_expert,
            size=(config.num_experts,),
            device="cuda",
            dtype=torch.int32,
        )

        return BatchedMMTensors(A, B, C, num_expert_tokens)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts", [8, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(
    num_experts: int,
    max_tokens_per_expert: int,
    K: int,
    N: int,
    dtype: torch.dtype,
    block_shape: list[int] | None,
    per_act_token_quant: bool,
):
    """Check the batched triton matmul kernel against two native references:
    an unquantized matmul and a quantized matmul on the same data.

    Note: float8_e4m3fn is not supported on CUDA architecture < 89,
    and those tests will be skipped on unsupported hardware."""
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
        89
    ):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
        pytest.skip("Don't test blocking for non-quantized types.")

    # Per-token and block quantization are mutually exclusive schemes.
    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # itemsize == 1 means an fp8 dtype: compute in bf16, quantize to `dtype`.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    num_expert_tokens = torch.randint(
        low=0,
        high=max_tokens_per_expert,
        size=(num_experts,),
        device="cuda",
        dtype=torch.int32,
    )

    A, A_q, A_scale = make_quantized_test_activations(
        num_experts,
        max_tokens_per_expert,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant,
    )

    # N // 2 because make_test_weights sizes w1 as [E, 2n, K]; only the first
    # weight bundle is used here.
    (B, B_q, B_scale, _), _ = make_test_weights(
        num_experts,
        N // 2,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_act_token_quant,
    )

    out_shape = (num_experts, max_tokens_per_expert, N)
    test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")

    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
    }[test_output.dtype]

    assert A_q.dtype == B_q.dtype

    invoke_moe_batched_triton_kernel(
        A_q,
        B_q,
        test_output,
        num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        A_scale,
        B_scale,
        None,
        # Quantization schemes
        use_fp8_w8a8,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            # fp8 needs a larger K block for the triton kernel.
            "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32,
        },
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    # Unquantized reference on the original activations/weights.
    ref_output = native_batched_masked_quant_matmul(
        A,
        B,
        ref_output,
        num_expert_tokens,
    )

    # Quantized reference on the same quantized operands the kernel saw.
    q_ref_output = native_batched_masked_quant_matmul(
        A_q,
        B_q,
        q_ref_output,
        num_expert_tokens,
        A_scale,
        B_scale,
        block_shape,
        per_act_token_quant,
    )

    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]

    # Quantization error bound, then kernel-vs-quantized-reference agreement.
    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
    torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: list[int] | None,
    input_scales: bool,
    workspace_init,
):
    """Compare three MoE implementations on the same problem: a torch
    reference (torch_experts), a naive batched implementation, and the
    batched triton kernels.

    Note: float8_e4m3fn is not supported on CUDA architecture < 89,
    and those tests will be skipped on unsupported hardware."""
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
        89
    ):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    if topk > e:
        pytest.skip("topk > e")

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    # Per-token and block quantization are mutually exclusive schemes.
    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    # itemsize == 1 means an fp8 dtype: compute in bf16, quantize to `dtype`.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    (w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights(
        e,
        n,
        k,
        block_shape=block_shape,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        per_out_ch_quant=per_act_token_quant,
    )

    if input_scales and quant_dtype is not None:
        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
    else:
        a1_scale = None
        a2_scale = None

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

        baseline_output = torch_experts(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )

        batched_output = naive_batched_moe(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )

        triton_output = batched_moe(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )

    # Naive batched vs torch reference, then triton vs naive batched.
    torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2)
    torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2)
|
||||
267
tests/kernels/moe/test_block_fp8.py
Normal file
267
tests/kernels/moe/test_block_fp8.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm_shape,
|
||||
deep_gemm_moe_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk,
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
dg_available = has_deep_gemm()
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
# Test configurations
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
|
||||
# and its hidden size is 7168.
|
||||
MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 128, 7168),
|
||||
(1, 1024, 7168),
|
||||
(1, 4608, 128),
|
||||
(1, 4608, 7168),
|
||||
(83, 128, 128),
|
||||
(83, 512, 512),
|
||||
(83, 4608, 512),
|
||||
(83, 4608, 7168),
|
||||
(128, 512, 512),
|
||||
(128, 1024, 7168),
|
||||
(128, 4608, 7168),
|
||||
(2048, 128, 128),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4608, 512),
|
||||
(2048, 4608, 7168),
|
||||
(8192, 128, 128),
|
||||
(8192, 128, 7168),
|
||||
(8192, 1024, 7168),
|
||||
(8192, 4608, 7168),
|
||||
]
|
||||
|
||||
MNK_FACTORS_DG = [
|
||||
(128, 128, 128),
|
||||
(128, 128, 7168),
|
||||
(128, 1024, 7168),
|
||||
(128, 4608, 128),
|
||||
(128, 4608, 7168),
|
||||
(192, 512, 512),
|
||||
(192, 1024, 7168),
|
||||
(192, 4608, 7168),
|
||||
(1335, 128, 128),
|
||||
(1335, 1024, 7168),
|
||||
(1335, 4608, 512),
|
||||
(1335, 4608, 7168),
|
||||
(2048, 128, 128),
|
||||
(2048, 128, 7168),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4608, 7168),
|
||||
]
|
||||
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
E = [2, 8, 16] # [128, 256]
|
||||
TOP_KS = [1, 2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
    """Fused moe with block-wise quantization using native torch.

    Reference implementation: replicates each token `topk` times, routes each
    replica through its selected expert (w1 → SiluAndMul → w2, with fp8
    per-token-group activation quantization), then does the weighted sum over
    the topk replicas.
    """
    B, D = a.shape
    topk = topk_ids.size(1)
    # Replicate each token topk times so replica i*topk+j pairs with topk_ids[i, j].
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    # Only the K-group size matters for activation quantization.
    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    a_q = a_q.to(torch.float32)
    # Process one expert at a time over the tokens routed to it.
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
            )
            act_out = SiluAndMul().forward_native(inter_out)
            # Re-quantize the activation before the second matmul.
            act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k)
            out[mask] = native_w8a8_block_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
            )
    # Weighted sum over the topk expert outputs per original token.
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
|
||||
|
||||
|
||||
# Skip all tests if CUDA is not available
|
||||
pytest.importorskip("torch.cuda")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_cuda():
    """Make "cuda" the default device for every test in this module."""
    torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(
    M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init
):
    """Check fused_experts and the modular triton MoE against the native
    torch block-fp8 reference (torch_w8a8_block_fp8_moe)."""
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)

    # Small chunk size so chunked execution paths are exercised.
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    w1, w2, quant_config = make_test_quant_config(
        E,
        N,
        K,
        dtype,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
        block_shape=block_size,
    )

    m_fused_moe = modular_triton_fused_moe(quant_config)

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(
            a,
            w1,
            w2,
            quant_config.w1_scale,
            quant_config.w2_scale,
            topk_weights,
            topk_ids,
            block_size,
        )

        out = fused_experts(
            a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
        )

        m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)

    # 0.039 only needed for M >= 8192
    tol = 0.035 if M < 8192 else 0.039
    torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
    torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch):
    """Check deep_gemm_moe_fp8 against the native torch block-fp8 reference,
    optionally replaying it under a CUDA graph for larger problem sizes."""
    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    chunk_size = 1024

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
    # DeepGEMM requires a specific M/K alignment; use it as the block size.
    block_size = get_mk_alignment_for_contiguous_layout()
    dtype = torch.bfloat16

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        E,
        N,
        K,
        dtype,
        torch.float8_e4m3fn,
        per_out_ch_quant=False,
        block_shape=block_size,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    # Only exercise the CUDA-graph path for chunked, reasonably-sized problems.
    use_cudagraph = (
        chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
    )

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(
            a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size
        )

        if use_compile:
            deep_gemm_moe_fp8_fn = torch.compile(
                deep_gemm_moe_fp8, backend="inductor", fullgraph=True
            )
            torch._dynamo.mark_dynamic(a, 0)
            torch._dynamo.mark_dynamic(topk_weights, 0)
            torch._dynamo.mark_dynamic(topk_ids, 0)
        else:
            deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8

        # Eager warm-up run (also the checked result when not using graphs).
        out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)

        if use_cudagraph:
            # Zero the warm-up result so a stale value can't mask a capture bug.
            out.fill_(0)
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                out = deep_gemm_moe_fp8_fn(
                    a, w1, w2, w1_s, w2_s, topk_weights, topk_ids
                )
            torch.cuda.synchronize()
            graph.replay()
            torch.cuda.synchronize()

    torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)
|
||||
134
tests/kernels/moe/test_block_int8.py
Normal file
134
tests/kernels/moe/test_block_int8.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_int8,
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
DTYPES = [torch.bfloat16]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 128, 7168),
|
||||
(1, 1024, 7168),
|
||||
(1, 4096, 512),
|
||||
(1, 4096, 7168),
|
||||
(33, 512, 512),
|
||||
(33, 128, 7168),
|
||||
(33, 1024, 7168),
|
||||
(33, 4096, 128),
|
||||
(33, 4096, 7168),
|
||||
(128, 128, 128),
|
||||
(128, 1024, 7168),
|
||||
(128, 4096, 512),
|
||||
(128, 4096, 7168),
|
||||
(222, 512, 512),
|
||||
(222, 1024, 7168),
|
||||
(222, 4096, 7168),
|
||||
(2048, 128, 128),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4096, 4096),
|
||||
]
|
||||
|
||||
E = [8, 24]
|
||||
TOP_KS = [2, 6]
|
||||
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """This function performs fused moe with block-wise quantization using
    native torch.

    Reference implementation: computes topk routing from `score`, replicates
    each token `topk` times, routes each replica through its selected expert
    (w1 → SiluAndMul → w2, with int8 per-token-group activation quantization),
    then does the weighted sum over the topk replicas.
    """
    B, D = a.shape
    # Replicate each token topk times so replica i*topk+j pairs with topk_ids[i, j].
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    # Only the K-group size matters for activation quantization.
    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    # Process one expert at a time over the tokens routed to it.
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
            )
            act_out = SiluAndMul().forward_native(inter_out)
            # Re-quantize the activation before the second matmul.
            act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k)
            # Fix: removed dead `act_out = act_out.to(torch.float32)` — the
            # unquantized activation is never read after quantization.
            out[mask] = native_w8a8_block_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
            )
    # Weighted sum over the topk expert outputs per original token.
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Sets the default CUDA device for all tests in this module."""
    # module scope: runs once, not per test.
    torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Tests the fused_moe kernel with W8A8 INT8 block quantization against a
    native torch reference."""
    torch.manual_seed(seed)

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    w1, w2, quant_config = make_test_quant_config(
        E,
        N,
        K,
        dtype,
        quant_dtype=torch.int8,
        per_act_token_quant=False,
        block_shape=block_size,
    )

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        out = fused_experts(
            a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
        )
        # Reference recomputes routing internally from the raw scores.
        ref_out = torch_w8a8_block_int8_moe(
            a,
            w1,
            w2,
            quant_config.w1_scale,
            quant_config.w2_scale,
            score,
            topk,
            block_size,
        )

    # Check results
    torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)
|
||||
143
tests/kernels/moe/test_count_expert_num_tokens.py
Normal file
143
tests/kernels/moe/test_count_expert_num_tokens.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests compute_expert_num_tokens kernels
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestTensors:
|
||||
topk_ids: torch.Tensor
|
||||
expert_map: torch.Tensor | None = None
|
||||
|
||||
def to_device(self, device: str):
|
||||
self.topk_ids = self.topk_ids.to(device=device)
|
||||
if self.expert_map is not None:
|
||||
self.expert_map = self.expert_map.to(device=device)
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
num_tokens: int,
|
||||
num_topk: int,
|
||||
num_experts: int,
|
||||
device: str,
|
||||
topk_ids_dtype: torch.dtype,
|
||||
) -> "TestTensors":
|
||||
# make topk ids
|
||||
topk_ids = torch.empty((num_tokens, num_topk), device=device, dtype=torch.int64)
|
||||
for x in range(num_tokens):
|
||||
topk_ids[x] = torch.randperm(num_experts)[:num_topk]
|
||||
topk_ids = topk_ids.to(dtype=torch.int64)
|
||||
return TestTensors(topk_ids=topk_ids)
|
||||
|
||||
def with_ep_rank(
|
||||
self, ep_rank: int, num_global_experts: int, num_local_experts: int, device: str
|
||||
):
|
||||
# make an expert map
|
||||
expert_map = torch.empty((num_global_experts), device=device, dtype=torch.int32)
|
||||
expert_map.fill_(-1)
|
||||
s = ep_rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)), device=device)
|
||||
|
||||
return TestTensors(topk_ids=self.topk_ids.clone(), expert_map=expert_map)
|
||||
|
||||
|
||||
def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
    """Accumulate per-expert token counts into `expert_num_tokens` (CPU reference)."""
    # Run the reference computation on the CPU.
    tt.to_device("cpu")
    unique_ids, occurrences = tt.topk_ids.unique(return_counts=True)

    for expert_id, num in zip(unique_ids, occurrences):
        # Translate a global expert id to its local id when a map is given.
        if tt.expert_map is not None and expert_id != -1:
            expert_id = tt.expert_map[expert_id]

        # -1 marks an expert that is not local to this rank; skip it.
        if expert_id != -1:
            expert_num_tokens[expert_id] += num
|
||||
|
||||
|
||||
def do_test_compute_expert_num_tokens(
    num_tokens: int,
    num_topk: int,
    num_experts: int,
    ep_size: int,
    topk_ids_dtype: torch.dtype,
):
    """Compare count_expert_num_tokens against the CPU reference for each EP rank.

    Builds random topk ids once, then for every expert-parallel rank checks the
    kernel both with an expert_map and with the map pre-applied to the ids.
    Counts are integers, so an exact (atol=0, rtol=0) match is required.
    """
    assert num_topk <= num_experts

    tt = TestTensors.make(
        num_tokens, num_topk, num_experts, topk_ids_dtype=topk_ids_dtype, device="cpu"
    )

    num_global_experts = num_experts
    # Experts must divide evenly across EP ranks for the slicing in with_ep_rank.
    assert num_global_experts % ep_size == 0
    num_local_experts = num_global_experts // ep_size
    for ep_rank in range(ep_size):
        tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, num_local_experts, "cpu")

        # Reference counts are computed on CPU, then moved to GPU for comparison.
        ref_expert_num_tokens = torch.zeros(
            (num_local_experts), device="cpu", dtype=torch.int32
        )
        ref_impl(tt_rank, ref_expert_num_tokens)
        ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")

        tt_rank.to_device("cuda")
        # Test with expert_map
        triton_expert_num_tokens_w_emap = count_expert_num_tokens(
            tt_rank.topk_ids, num_local_experts, tt_rank.expert_map
        )

        # Test without expert map: apply the map to the ids up front instead,
        # so non-local experts become -1 in the id tensor itself.
        topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
        triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
            topk_ids, num_local_experts, expert_map=None
        )

        torch.testing.assert_close(
            ref_expert_num_tokens, triton_expert_num_tokens_w_emap, atol=0, rtol=0
        )
        torch.testing.assert_close(
            ref_expert_num_tokens, triton_expert_num_tokens_wo_emap, atol=0, rtol=0
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317])
@pytest.mark.parametrize("num_topk", [2, 6, 8])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("ep_size", [1, 2, 4])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens(
    num_tokens: int,
    num_topk: int,
    num_experts: int,
    ep_size: int,
    topk_ids_dtype: torch.dtype,
):
    """Sweep token counts, topk widths and EP sizes through the counting kernel."""
    do_test_compute_expert_num_tokens(
        num_tokens=num_tokens,
        num_topk=num_topk,
        num_experts=num_experts,
        ep_size=ep_size,
        topk_ids_dtype=topk_ids_dtype,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("numel", list(range(1, 8192, 111)))
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("ep_size", [2])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens_from_numel(
    numel: int, num_experts: int, ep_size: int, topk_ids_dtype: torch.dtype
):
    """Exercise many flat sizes with a single topk slot per token."""
    do_test_compute_expert_num_tokens(numel, 1, num_experts, ep_size, topk_ids_dtype)
|
||||
582
tests/kernels/moe/test_cutedsl_moe.py
Normal file
582
tests/kernels/moe/test_cutedsl_moe.py
Normal file
@@ -0,0 +1,582 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
import torch
|
||||
from flashinfer import fp4_quantize
|
||||
from torch.nn import functional as F
|
||||
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
|
||||
flashinfer_cutedsl_moe_masked,
|
||||
)
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked,
|
||||
)
|
||||
from vllm.utils.flashinfer import (
|
||||
scaled_fp4_grouped_quantize,
|
||||
)
|
||||
|
||||
# Lookup table mapping the 3-bit E2M1 magnitude code of an fp4 value to its
# float value (the sign bit is handled separately by the unpacking code).
kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)

# Largest finite magnitudes representable in float8_e4m3fn and float4_e2m1,
# used to derive global quantization scales from tensor amax values.
FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0
|
||||
|
||||
|
||||
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Undo the 128-row tile swizzle of a scale-factor tensor.

    Returns a row-major (m, k-ish) view of the swizzled input, cropped to
    [0:m, 0:k].
    """
    num_m_tiles = -(-m // 128)  # ceil(m / 128)
    fold = block_size * 4
    num_k_tiles = -(-k // fold)  # ceil(k / fold)
    tiled = a_sf_swizzled.reshape(1, num_m_tiles, num_k_tiles, 32, 4, 4)
    # Reorder the tile axes back into linear row/column order.
    linear = tiled.permute(0, 1, 4, 3, 2, 5).reshape(
        num_m_tiles * 128, num_k_tiles * fold // block_size
    )
    return linear[:m, :k]
|
||||
|
||||
|
||||
def dequantize_nvfp4_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8 byte.
    assert tensor_fp4.dtype == torch.uint8
    rows, packed_cols = tensor_fp4.shape
    cols = packed_cols * 2

    # Unpack nibbles to floats, grouped into per-block columns.
    unpacked = break_fp4_bytes(tensor_fp4, dtype)
    unpacked = unpacked.reshape(rows, cols // block_size, block_size)

    # De-swizzle the per-block scale factors and undo the global scale.
    scales = convert_swizzled_to_linear(
        tensor_sf.view(torch.float8_e4m3fn), rows, cols, block_size
    )
    scales = scales.to(torch.float32) / global_scale

    # Apply one scale per block, then flatten back to (rows, cols).
    result = (unpacked * scales.unsqueeze(-1)).reshape(rows, cols)
    return result.to(dtype=dtype)
|
||||
|
||||
|
||||
def break_fp4_bytes(a, dtype):
    """Unpack two fp4 (E2M1) values per uint8 byte into real numbers."""
    assert a.dtype == torch.uint8
    rows, packed_cols = a.shape

    # Split each byte into its two nibbles; the low nibble comes first.
    flat = a.flatten()
    upper = (flat & 0xF0) >> 4
    lower = flat & 0x0F
    nibbles = torch.stack((lower, upper), dim=1).flatten()

    # Bit 3 carries the sign, bits 0-2 index the E2M1 magnitude table.
    negative = (nibbles & 0x08).to(torch.bool)
    magnitude_idx = (nibbles & 0x07).to(torch.long)

    # Device-aware table lookup, then apply the sign.
    lut = kE2M1ToFloat.to(device=a.device)
    decoded = lut[magnitude_idx] * torch.where(negative, -1.0, 1.0)

    # Each input column expands to two output columns.
    return decoded.reshape(rows, packed_cols * 2).to(dtype=dtype)
|
||||
|
||||
|
||||
def generate_balanced_routing(
    hidden_states: torch.Tensor, num_experts: int, top_k: int
):
    """
    Generate routing weights and topk indices such that every expert is active.
    Returns routing_weights, topk_idx
    """
    num_tokens, _ = hidden_states.shape

    # Round-robin one guaranteed expert per token, then shuffle the pairing.
    guaranteed = torch.arange(num_tokens) % num_experts
    guaranteed = guaranteed[torch.randperm(num_tokens)]

    expert_idx = torch.full((num_tokens, top_k), -1, dtype=torch.long)
    expert_idx[:, 0] = guaranteed

    # Remaining slots pick experts uniformly at random (repeats allowed).
    if top_k > 1:
        expert_idx[:, 1:] = torch.randint(0, num_experts, (num_tokens, top_k - 1))

    # Random weights normalized so each token's weights sum to 1.
    weights = torch.rand(num_tokens, top_k)
    weights = weights / weights.sum(dim=-1, keepdim=True)

    return weights.view(num_tokens, top_k), expert_idx.view(num_tokens, top_k)
|
||||
|
||||
|
||||
def prepare_inputs(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    num_experts: int,
    topk: int,
):
    """Route tokens and gather them into a per-expert padded 3-D layout.

    Returns (hidden_states_3d, masked_m, topk_idx, routing_weights) where
    hidden_states_3d[e, :masked_m[e]] holds the tokens routed to expert e.
    """
    routing_weights, topk_idx = generate_balanced_routing(
        router_logits, num_experts, topk
    )

    flat_idx = topk_idx.view(-1)
    # Number of token slots routed to each expert.
    masked_m = torch.tensor(
        [(flat_idx == e).sum() for e in range(num_experts)], dtype=torch.int32
    )

    # Initialize with ones instead of empty to avoid nan issues in the
    # padded (unused) positions.
    hidden_states_3d = torch.ones(
        (num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
    )
    for e in range(num_experts):
        hidden_states_3d[e, : masked_m[e], :] = hidden_states[flat_idx == e]

    return hidden_states_3d, masked_m, topk_idx, routing_weights
|
||||
|
||||
|
||||
# (M, N, K) problem sizes covered by the tests in this file.
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1024),
    (224, 1024, 1024),
    (224, 1024, 1536),
]
|
||||
|
||||
|
||||
# Reference implementation of torch_moe
|
||||
# Reference implementation of torch_moe
def torch_moe(a, w1, w2, score, topk, expert_map):
    """Reference MoE: softmax-route each token to topk experts and mix outputs."""
    num_tokens, dim = a.shape
    # Replicate every token once per selected expert slot.
    a = a.view(num_tokens, -1, dim).repeat(1, topk, 1).reshape(-1, dim)
    out = torch.zeros(num_tokens * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(probs, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]

    for expert in range(w1.shape[0]):
        selected = topk_ids == expert
        if selected.sum():
            gate_up = a[selected] @ w1[expert].transpose(0, 1)
            out[selected] = SiluAndMul()(gate_up) @ w2[expert].transpose(0, 1)

    # Weight each expert output by its routing weight and sum per token.
    weighted = out.view(num_tokens, -1, w2.shape[1]) * topk_weight.view(
        num_tokens, -1, 1
    ).to(out.dtype)
    return weighted.sum(dim=1)
|
||||
|
||||
|
||||
def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
    """Reference MoE that mimics nvfp4 kernels by re-quantizing the
    intermediate activation through fp4 before the down projection.

    Args:
        a: (B, D) input tokens.
        w1: (E, 2*inter, D) fused gate/up weights per expert.
        w2: (E, D, inter) down-projection weights per expert.
        topk: number of experts chosen per token.
        topk_weight / topk_ids: (B, topk) routing weights and expert ids.
    """
    B, D = a.shape
    # Replicate every token once per selected expert slot.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            m = w1[i].shape[0]
            assert m % 2 == 0
            # Note: w1 and w3 are swapped!
            w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
            # SwiGLU: silu(gate) * up.
            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
            # Round-trip the intermediate through fp4 (global scale 1.0) to
            # match the quantization error of the fused kernel.
            inter_gs = torch.tensor(1.0).cuda()
            inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
            inter = dequantize_nvfp4_to_dtype(
                inter_q,
                inter_blockscale,
                inter_gs,
                dtype=inter.dtype,
                device=inter.device,
                block_size=16,
            ).cuda()
            out[mask] = inter @ w2[i].transpose(0, 1)
    # Weight each expert output by its routing weight and sum per token.
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
|
||||
|
||||
|
||||
def grouped_gemm_ref(
    hidden_states_expanded: torch.Tensor,
    hidden_states_3d: torch.Tensor,
    weights: torch.Tensor,
    topk_idx: torch.Tensor,
    masked_m: torch.Tensor,
    B: int,
    topk: int,
    num_experts: int,
    *,
    block_size: int = 16,
) -> torch.Tensor:
    """
    Computes the reference grouped GEMM as an fp4-quantized per-expert loop
    and returns the repacked reference output.

    NOTE(review): hidden_states_3d is accepted but never read here —
    presumably kept for signature symmetry with the kernel path; confirm.

    Returns:
        out_ref: Tensor [num_experts, max_m, n_out]
    """
    device_hs = hidden_states_expanded.device
    device_w = weights.device
    out_dtype = weights.dtype
    n_out = weights.shape[1]

    # Flattened reference output (B*topk, n_out)
    out = torch.zeros((B * topk, n_out), dtype=out_dtype, device=device_w)

    # Per-expert reference compute loop
    for i in range(num_experts):
        mask = topk_idx.view(-1) == i
        if mask.any():
            lhs = hidden_states_expanded[mask]
            rhs = weights[i]

            # Global scales derived from per-operand amax, matching the
            # fp8/fp4 dynamic-range product used by the kernel.
            a_amax = lhs.abs().max().to(torch.float32).to(device_hs)
            b_amax = rhs.abs().max().to(torch.float32).to(device_w)

            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax

            # Quantize and immediately dequantize both operands so the
            # reference GEMM sees the same quantization error as the kernel.
            lhsq, lhsq_sf = fp4_quantize(lhs, a_gs)
            rhsq, rhsq_sf = fp4_quantize(rhs, b_gs)

            lhs_in_dtype = dequantize_nvfp4_to_dtype(
                lhsq,
                lhsq_sf,
                a_gs,
                dtype=lhs.dtype,
                device=device_hs,
                block_size=block_size,
            )
            rhs_in_dtype = dequantize_nvfp4_to_dtype(
                rhsq,
                rhsq_sf,
                b_gs,
                dtype=rhs.dtype,
                device=device_w,
                block_size=block_size,
            )

            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()

    # Determine per-expert max_m
    max_m_val = int(masked_m.max().item())

    # Repack the flat (B*topk, n_out) output into [num_experts, max_m, n_out],
    # filling each expert's rows in token order.
    out_ref = torch.zeros(
        (num_experts, max_m_val, n_out),
        dtype=out.dtype,
        device=out.device,
    )
    expert_slot = [0] * num_experts

    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
        slot = expert_slot[expert_id]
        if slot < max_m_val:
            out_ref[expert_id, slot, :] = out[i]
            expert_slot[expert_id] += 1
        else:
            # masked_m should bound the per-expert row count; overflow means
            # the routing metadata is inconsistent.
            raise IndexError(
                f"Expert {expert_id} exceeded max slots ({max_m_val}). "
                "Increase max_m or check masked_m."
            )

    return out_ref
|
||||
|
||||
|
||||
def flashinfer_cutedsl_grouped_gemm_nt_masked(
    hidden_states: torch.Tensor,  # 3d
    input_global_scale: torch.Tensor,  # (l,)
    weights: torch.Tensor,
    w_global_scale: torch.Tensor,  # (l,)
    masked_m: torch.Tensor,
):
    """Run the flashinfer CuteDSL masked grouped GEMM on fp4-quantized inputs.

    Only the first masked_m[e] rows of each expert's slice are valid in the
    returned tensor (shape (max_m, n, l) after the kernel-required permute).
    """
    # hidden_states: [l, m, k]
    # weights: [l, n, k]
    # Quantize activations per expert, honoring the per-expert row mask.
    aq, aq_sf = scaled_fp4_grouped_quantize(
        hidden_states,
        masked_m.to(hidden_states.device),
        input_global_scale,
    )
    num_experts, n, k = weights.shape
    # Weights are fully dense, so every expert quantizes all n rows.
    bq, bq_sf = scaled_fp4_grouped_quantize(
        weights,
        torch.full((num_experts,), n, device=weights.device, dtype=torch.int32),
        w_global_scale,
    )

    out = torch.zeros(
        (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
    )
    out = out.permute(1, 2, 0)  # requirement of kernel
    sf_vec_size = 16
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"
    c_dtype = "bfloat16"
    # alpha undoes the combined global quantization scales per expert.
    alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
        1, 1, num_experts
    )

    def get_cute_dtype(input: torch.Tensor) -> str:
        # Map a torch dtype to the string name the CuteDSL kernel expects.
        if input.dtype == torch.bfloat16:
            return "bfloat16"
        elif input.dtype == torch.float16:
            return "float16"
        elif input.dtype == torch.float32:
            return "float32"
        else:
            raise ValueError(f"Unsupported cute dtype {input.dtype}")

    # The kernel writes its result into `out` in place.
    cutedsl_gmm_masked(
        (aq, aq_sf),
        (bq, bq_sf),
        out,
        masked_m.to(aq.device),
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=alpha,
        alpha_dtype=get_cute_dtype(alpha),
    )

    return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
@pytest.mark.parametrize("topk", [1, 2, 4])
@torch.inference_mode()
def test_flashinfer_cutedsl_moe_masked(
    bs: int, hidden_dim: int, inter_dim: int, topk: int
):
    """End-to-end check of flashinfer_cutedsl_moe_masked against a torch
    reference that applies the same fp4 quantization round-trips."""
    torch.manual_seed(42)
    device = "cuda"
    num_experts = 8
    # Small magnitudes keep activations inside the fp4 dynamic range.
    hidden_states = (
        torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
    )
    w1 = (
        torch.randn(
            num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
        )
        / 10.0
    )
    w2 = (
        torch.randn(
            num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
        )
        / 10.0
    )
    router_logits = torch.randn(bs, num_experts, dtype=torch.float32)

    # Replicate tokens per topk slot and gather into the masked 3-D layout
    # the kernel consumes.
    hidden_states_expanded = (
        hidden_states.view(bs, -1, hidden_dim)
        .repeat(1, topk, 1)
        .reshape(-1, hidden_dim)
    )
    hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
        hidden_states_expanded, router_logits, num_experts, topk
    )

    # Per-expert global scales from weight amax; activation scales fixed at 1.
    w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
    w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
    input_global_scale = torch.ones(
        (num_experts,), dtype=torch.float32, device=hidden_states.device
    )

    w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
    w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
    a2_global_scale = torch.ones(
        (num_experts,), dtype=torch.float32, device=hidden_states.device
    )  # assume intermediate scale is 1.0

    w1_fp4, w1_blockscale = scaled_fp4_grouped_quantize(
        w1,
        torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
        w1_global_scale,
    )
    w2_fp4, w2_blockscale = scaled_fp4_grouped_quantize(
        w2,
        torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
        w2_global_scale,
    )

    # alpha undoes the combined input*weight global scales for each GEMM.
    w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
    w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)

    out = torch.empty_like(hidden_states_3d)
    # Note: the 1st dim shouldn't be bs
    wk = torch.empty(
        num_experts,
        hidden_states_3d.shape[1],
        inter_dim * 2,
        dtype=hidden_states_3d.dtype,
        device=hidden_states.device,
    )
    # Kernel under test writes its result into `out` in place.
    flashinfer_cutedsl_moe_masked(
        hidden_states_3d.to(hidden_states.device),
        input_global_scale,
        w1_fp4.permute(2, 0, 1),
        w1_blockscale,
        w1_alpha,
        w2_fp4.permute(2, 0, 1),
        a2_global_scale,
        w2_blockscale,
        w2_alpha,
        masked_m.to(hidden_states.device),
        wk,
        out,
    )

    # reference: round-trip activations and weights through fp4 so the torch
    # path sees the same quantization error as the kernel.
    a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
    a_in_dtype = dequantize_nvfp4_to_dtype(
        a_fp4,
        a_scale_interleaved,
        input_global_scale,
        dtype=hidden_states.dtype,
        device=hidden_states.device,
        block_size=16,
    )
    w1_d = torch.empty(
        (num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
    )
    w2_d = torch.empty(
        (num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
    )

    for idx in range(0, num_experts):
        w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
            w1[idx], w1_global_scale[idx]
        )
        w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
            w2[idx], w2_global_scale[idx]
        )
        w1_d[idx] = dequantize_nvfp4_to_dtype(
            w1_fp4_sliced,
            w1_blockscale_sliced,
            w1_global_scale[idx],
            dtype=w1.dtype,
            device=w1.device,
            block_size=16,
        )
        w2_d[idx] = dequantize_nvfp4_to_dtype(
            w2_fp4_sliced,
            w2_blockscale_sliced,
            w2_global_scale[idx],
            dtype=w2.dtype,
            device=w2.device,
            block_size=16,
        )

    ref_output = torch_moe_nvfp4(
        a_in_dtype,
        w1_d,
        w2_d,
        topk,
        routing_weights.to(a_in_dtype.device),
        topk_idx.to(a_in_dtype.device),
    )
    out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)

    # Scatter the kernel's per-expert rows back to tokens, applying routing
    # weights, so it can be compared against the token-major reference.
    positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
    rows, cols = positions[:, 0], positions[:, 1]
    experts = topk_idx[rows, cols]
    for i in range(num_experts):
        mask = experts == i
        if mask.any():
            idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
            r, c = rows[idx], cols[idx]
            out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
                out.device
            ).unsqueeze(-1)
    # Loose tolerances: two fp4 quantization stages dominate the error.
    torch.testing.assert_close(
        out_weighted.cpu(), ref_output.cpu(), atol=2e-1, rtol=2e-1
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
)
@torch.inference_mode()
def test_grouped_gemm_nt_masked(
    bs: int, hidden_dim: int, inter_dim: int, topk: int
) -> None:
    """Compare the CuteDSL masked grouped GEMM against the fp4 torch reference,
    checking only the valid (masked) rows of each expert."""
    torch.manual_seed(42)
    B = bs
    D = hidden_dim
    N = inter_dim
    # CuteDSL group gemm has issue when not all experts are active.
    # i.e. masked = [2, 3, 0, 0, 1] where the 2nd and 3rd experts are inactive
    # see https://github.com/flashinfer-ai/flashinfer/issues/1856
    num_experts = bs
    hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
    weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
    router_logits = torch.randn(B, num_experts, dtype=torch.float32)

    # Replicate tokens per topk slot and gather into the masked 3-D layout.
    hidden_states_expanded = (
        hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    )
    hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
        hidden_states_expanded, router_logits, num_experts, topk
    )

    # Per-expert global scales from operand amax values.
    a_amax = (
        hidden_states_3d.abs()
        .amax(dim=(1, 2))
        .to(torch.float32)
        .to(hidden_states.device)
    )
    b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
    a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
    b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
    out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
        hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
    )
    # reference
    out_ref = grouped_gemm_ref(
        hidden_states_expanded=hidden_states_expanded,
        hidden_states_3d=hidden_states_3d,
        weights=weights,
        topk_idx=topk_idx,
        masked_m=masked_m,
        B=B,
        topk=topk,
        num_experts=num_experts,
    )
    # Note: just to compare the masked position due to cutedsl may write nan
    # into unmasked position.
    for i in range(num_experts):
        torch.testing.assert_close(
            out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
            out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
            atol=1e-1,
            rtol=1e-1,
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc entry point: run one configuration of each test directly
    # without pytest (useful for quick local debugging on a GPU box).
    test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
    test_grouped_gemm_nt_masked(16, 128, 512, 4)
|
||||
92
tests/kernels/moe/test_cutlass_grouped_gemm.py
Normal file
92
tests/kernels/moe/test_cutlass_grouped_gemm.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# DeepGEMM Style Cutlass Grouped GEMM Test
|
||||
# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import per_token_cast_to_fp8
|
||||
from tests.kernels.utils import baseline_scaled_mm
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_groups, expected_m_per_group, k, n",
    [
        (4, 8192, 7168, 4096),
        (4, 8192, 2048, 7168),
        (8, 4096, 7168, 4096),
        (8, 4096, 2048, 7168),
        (32, 1024, 7168, 4096),
        (32, 1024, 2048, 7168),
    ],
)
@pytest.mark.parametrize("out_dtype", [torch.float16])
@pytest.mark.skipif(
    (lambda x: x is None or x.to_int() != 100)(
        current_platform.get_device_capability()
    ),
    reason="Block Scaled Grouped GEMM is only supported on SM100.",
)
def test_cutlass_grouped_gemm(
    num_groups: int,
    expected_m_per_group: int,
    k: int,
    n: int,
    out_dtype: torch.dtype,
):
    """Check cutlass_blockwise_scaled_grouped_mm against a per-group
    baseline_scaled_mm over fp8 block-quantized operands."""
    device = "cuda"
    alignment = 128
    # Randomize each group's row count around the expected size.
    group_ms = [
        int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)
    ]
    # NOTE(review): the comprehension variable shadows the outer `m` below.
    m = sum([cdiv(m, alignment) * alignment for m in group_ms])

    x = torch.randn((m, k), device=device, dtype=out_dtype)
    y = torch.randn((num_groups, n, k), device=device, dtype=out_dtype)
    out = torch.empty((m, n), device=device, dtype=out_dtype)
    # ref_out starts as random noise; every row is overwritten by the
    # baseline loop below, since the last offset equals m.
    ref_out = torch.randn((m, n), device=device, dtype=out_dtype)

    # Per-group row offsets. NOTE(review): interior offsets use the raw
    # (unaligned) group sizes while m was padded to `alignment`, so the
    # alignment padding all lands in the final group — confirm intended.
    ep_offset = [0] + [sum(group_ms[:i]) for i in range(1, num_groups)] + [m]
    pb_size = []
    for i in range(num_groups):
        pb_size.append([ep_offset[i + 1] - ep_offset[i], n, k])
    problem_sizes = torch.tensor(pb_size, device=device, dtype=torch.int32)
    expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32)

    # Activations: per-token fp8; weights: 128x128 block fp8 per group.
    x_fp8 = per_token_cast_to_fp8(x)
    y_fp8 = (
        torch.empty_like(y, dtype=torch.float8_e4m3fn),
        torch.empty(
            (num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float
        ),
    )
    for i in range(num_groups):
        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])

    # Baseline: one scaled matmul per group over its row slice.
    for i in range(num_groups):
        a = x_fp8[0][ep_offset[i] : ep_offset[i + 1]]
        a_scale = x_fp8[1][ep_offset[i] : ep_offset[i + 1]]
        b = y_fp8[0][i].t()
        b_scale = y_fp8[1][i].t()
        baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype)
        ref_out[ep_offset[i] : ep_offset[i + 1]] = baseline

    # Kernel under test writes into `out` in place; the final offset (== m)
    # is implicit, hence expert_offsets[:-1].
    ops.cutlass_blockwise_scaled_grouped_mm(
        out,
        x_fp8[0],
        y_fp8[0],
        x_fp8[1],
        y_fp8[1],
        problem_sizes,
        expert_offsets[:-1],
    )

    torch.testing.assert_close(ref_out, out, atol=5e-1, rtol=1e-3)
|
||||
554
tests/kernels/moe/test_cutlass_moe.py
Normal file
554
tests/kernels/moe/test_cutlass_moe.py
Normal file
@@ -0,0 +1,554 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import dataclasses
|
||||
from math import prod
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8,
|
||||
run_cutlass_moe_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [40, 64]
|
||||
TOP_KS = [6, 8]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(2, 1024, 1024),
|
||||
(2, 3072, 1024),
|
||||
(2, 3072, 1536),
|
||||
(7, 3072, 1536),
|
||||
(64, 1024, 1024),
|
||||
(64, 1024, 1536),
|
||||
(64, 3072, 1024),
|
||||
(224, 1024, 1024),
|
||||
(224, 3072, 1024),
|
||||
(224, 3072, 1536),
|
||||
(32768, 1024, 1024),
|
||||
# These sizes trigger wrong answers.
|
||||
# (7232, 2048, 5120),
|
||||
# (40000, 2048, 5120),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MOETensors:
    """Unquantized MoE inputs plus the stride tensors the cutlass path needs."""

    a: torch.Tensor
    w1: torch.Tensor
    w2: torch.Tensor
    ab_strides1: torch.Tensor
    c_strides1: torch.Tensor
    ab_strides2: torch.Tensor
    c_strides2: torch.Tensor

    @staticmethod
    def make_moe_tensors(
        m: int, k: int, n: int, e: int, dtype: torch.dtype
    ) -> "MOETensors":
        """Create random activations/weights on CUDA plus per-expert strides."""

        def stride_tensor(value: int) -> torch.Tensor:
            # One identical stride entry per expert.
            return torch.full((e,), value, device="cuda", dtype=torch.int64)

        # Scale down so values stay well inside the fp8 range later on.
        act = torch.randn((m, k), device="cuda", dtype=dtype) / 10
        gate_up = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        down = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
        return MOETensors(
            a=act,
            w1=gate_up,
            w2=down,
            ab_strides1=stride_tensor(k),
            c_strides1=stride_tensor(2 * n),
            ab_strides2=stride_tensor(n),
            c_strides2=stride_tensor(k),
        )
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MOETensors8Bit(MOETensors):
    """MOETensors plus fp8-quantized variants and their dequantized copies.

    The ``*_d`` tensors exist so reference implementations can consume inputs
    that carry exactly the same quantization rounding as the fp8 path.
    """

    # quantized
    a_q: torch.Tensor | None = None  # a -> a_q
    w1_q: torch.Tensor | None = None  # w1 -> w1_q
    w2_q: torch.Tensor | None = None  # w2 -> w2_q
    a_scale: torch.Tensor | None = None
    w1_scale: torch.Tensor | None = None
    w2_scale: torch.Tensor | None = None
    # dequantized
    a_d: torch.Tensor | None = None  # a -> a_q -> a_d
    w1_d: torch.Tensor | None = None  # w1 -> w1_q -> w1_d
    w2_d: torch.Tensor | None = None  # w2 -> w2_q -> w2_d

    @staticmethod
    def make_moe_tensors_8bit(
        m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool
    ) -> "MOETensors8Bit":
        """Build fp16 tensors, fp8-quantize them, then dequantize for reference.

        ``per_act_token`` / ``per_out_channel`` select dynamic per-token
        activation scales and per-output-channel weight scales respectively;
        otherwise a single per-tensor scale is used (scale dim of 1).
        """
        dtype = torch.half
        q_dtype = torch.float8_e4m3fn

        moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype)

        # a -> a_q, w1 -> w1_q, w2 -> w2_q
        n_b_scales = 2 * n if per_out_channel else 1
        k_b_scales = k if per_out_channel else 1
        # Get the right scale for tests.
        a_q, a_scale = ops.scaled_fp8_quant(
            moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token
        )

        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)

        w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
        w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
        # Quantize each expert's weight matrix independently so every expert
        # gets its own scale(s).
        for expert in range(e):
            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
                moe_tensors_fp16.w1[expert], use_per_token_if_dynamic=per_out_channel
            )
            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
                moe_tensors_fp16.w2[expert], use_per_token_if_dynamic=per_out_channel
            )

        # a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d
        a_d = a_q.float().mul(a_scale).to(dtype)
        w1_d = torch.empty_like(moe_tensors_fp16.w1)
        w2_d = torch.empty_like(moe_tensors_fp16.w2)
        for expert in range(e):
            w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
            w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()

        return MOETensors8Bit(
            a=moe_tensors_fp16.a,
            w1=moe_tensors_fp16.w1,
            w2=moe_tensors_fp16.w2,
            ab_strides1=moe_tensors_fp16.ab_strides1,
            c_strides1=moe_tensors_fp16.c_strides1,
            ab_strides2=moe_tensors_fp16.ab_strides2,
            c_strides2=moe_tensors_fp16.c_strides2,
            a_q=a_q,
            w1_q=w1_q,
            w2_q=w2_q,
            a_scale=a_scale,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a_d=a_d,
            w1_d=w1_d,
            w2_d=w2_d,
        )
|
||||
|
||||
|
||||
def run_with_expert_maps(num_experts: int, num_local_experts: int, **cutlass_moe_kwargs):
    """Emulate expert parallelism by running cutlass_moe_fp8 per expert shard.

    Experts are split into ``num_experts // num_local_experts`` shards. Each
    shard runs with an expert map that assigns only its own experts to local
    slots (-1 elsewhere), and the per-shard outputs are summed — tokens routed
    to non-local experts contribute nothing in that shard.

    Returns the summed output tensor, shaped like ``cutlass_moe_kwargs["a"]``.
    """

    def slice_experts():
        # kwargs whose tensors carry a leading expert dimension and must be
        # sliced to the current shard.
        slice_params = (
            "w1_q",
            "w2_q",
            "ab_strides1",
            "ab_strides2",
            "c_strides1",
            "c_strides2",
        )
        # Keep references to the full tensors: cutlass_moe_kwargs entries are
        # overwritten with shard views on every iteration below.
        # (The original comprehension also tested `k in cutlass_moe_kwargs`,
        # which is always true while iterating its own .items() — removed.)
        full_tensors = {
            k: v for k, v in cutlass_moe_kwargs.items() if k in slice_params
        }

        quant_config = cutlass_moe_kwargs["quant_config"]

        for i in range(0, num_experts, num_local_experts):
            s, e = i, i + num_local_experts

            # make expert map: global expert id -> local slot, -1 if non-local
            expert_map = [-1] * num_experts
            expert_map[s:e] = list(range(num_local_experts))
            expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

            # update cutlass moe arg with expert_map
            cutlass_moe_kwargs["expert_map"] = expert_map
            # update cutlass moe arg tensors
            for k, t in full_tensors.items():
                cutlass_moe_kwargs[k] = t[s:e]

            # Shard the weight scales too; deep copy so the caller's
            # quant_config is never mutated.
            new_quant_config = copy.deepcopy(quant_config)
            new_quant_config._w1.scale = quant_config.w1_scale[s:e]
            new_quant_config._w2.scale = quant_config.w2_scale[s:e]

            cutlass_moe_kwargs["quant_config"] = new_quant_config

            yield cutlass_moe_kwargs

    out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
    for kwargs in slice_experts():
        out_tensor = out_tensor + cutlass_moe_fp8(**kwargs)

    return out_tensor
|
||||
|
||||
|
||||
def run_8_bit(
    moe_tensors: MOETensors8Bit,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    per_act_token: bool,
    per_out_ch: bool,
    num_local_experts: int | None = None,
) -> torch.Tensor:
    """Run the CUTLASS fp8 MoE, optionally sharded via expert maps.

    When ``num_local_experts`` is None the whole problem runs in one
    cutlass_moe_fp8 call; otherwise experts are split into shards of
    ``num_local_experts`` and partial outputs summed (run_with_expert_maps).
    """
    # All quantized tensors must have been populated by make_moe_tensors_8bit.
    assert all(
        t is not None
        for t in (
            moe_tensors.w1_q,
            moe_tensors.w2_q,
            moe_tensors.w1_scale,
            moe_tensors.w2_scale,
            moe_tensors.a_scale,
        )
    )

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=moe_tensors.w1_scale,
        w2_scale=moe_tensors.w2_scale,
        per_act_token_quant=per_act_token,
        per_out_ch_quant=per_out_ch,
        # Set to moe_tensors.a_scale iff static scales + per tensor.
        # This is not currently being tested.
        a1_scale=None,
    )

    kwargs = {
        "a": moe_tensors.a,
        "w1_q": moe_tensors.w1_q,  # type: ignore[union-attr]
        "w2_q": moe_tensors.w2_q,  # type: ignore[union-attr]
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "ab_strides1": moe_tensors.ab_strides1,
        "ab_strides2": moe_tensors.ab_strides2,
        "c_strides1": moe_tensors.c_strides1,
        "c_strides2": moe_tensors.c_strides2,
        "quant_config": quant_config,
    }

    num_experts = moe_tensors.w1.size(0)
    # NOTE: the original condition was `num_local_experts is not None or
    # num_local_experts == num_experts`; the second clause was dead code
    # (None never equals an int), so the expression reduces to exactly this.
    with_ep = num_local_experts is not None
    if not with_ep:
        return cutlass_moe_fp8(**kwargs)

    assert num_local_experts is not None
    return run_with_expert_maps(
        num_experts,
        num_local_experts,  # type: ignore[arg-type]
        **kwargs,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
    workspace_init,
    ep_size: int | None = None,
):
    """Compare the CUTLASS fp8 MoE against the Triton reference (eager mode).

    Also reused by the EP tests below, which pass a non-None ``ep_size`` to
    route through expert maps.
    """
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)

        score = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)

        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.

        quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
        triton_output = fused_experts(
            mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
        )

        if ep_size is not None:
            assert e % ep_size == 0, "Cannot distribute experts evenly"
            number_local_experts = e // ep_size
        else:
            number_local_experts = None

        cutlass_output = run_8_bit(
            mt, topk_weights, topk_ids, per_act_token, per_out_ch, number_local_experts
        )

        # Note 5.5 only needed for larger problem sizes, 5 works ok for
        # the rest.
        torch.testing.assert_close(
            triton_output, cutlass_output, atol=5.5e-2, rtol=1e-2
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_cuda_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
    workspace_init,
):
    """Same comparison as the no-graph test, but with the CUTLASS path
    captured into a CUDA graph and replayed before checking results."""
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        dtype = torch.half

        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)

        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.
        quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
        triton_output = fused_experts(
            mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
        )

        # Capture the CUTLASS MoE into a CUDA graph on a side stream, then
        # replay it; outputs are read only after explicit synchronization.
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            cutlass_output = run_8_bit(
                mt, topk_weights, topk_ids, per_act_token, per_out_ch
            )

        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

        torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [64])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [4096])
@pytest.mark.parametrize("e", [16])
@pytest.mark.parametrize("topk", [1, 8])
@pytest.mark.parametrize("per_act_token", [True])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_EP(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    monkeypatch,
    workspace_init,
):
    """Expert-parallel variant: delegates to the no-graph test with ep_size set."""
    test_cutlass_moe_8_bit_no_graph(
        m=m,
        n=n,
        k=k,
        e=e,
        topk=topk,
        per_act_token=per_act_token,
        per_out_ch=per_out_channel,
        monkeypatch=monkeypatch,
        workspace_init=workspace_init,
        ep_size=ep_size,
    )
|
||||
|
||||
|
||||
# (m, n, k, topk) shapes for the large expert-parallel test below.
LARGE_MNK_FACTORS = [
    (1, 8192, 5120, 31),
    (32768, 1024, 1024, 16),
    (65536, 512, 1024, 16),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_EP_large(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    monkeypatch,
    workspace_init,
):
    """Expert-parallel test on large problem sizes; delegates to the
    no-graph test with ep_size set."""
    test_cutlass_moe_8_bit_no_graph(
        m=m,
        n=n,
        k=k,
        e=e,
        topk=topk,
        per_act_token=per_act_token,
        per_out_ch=per_out_channel,
        monkeypatch=monkeypatch,
        workspace_init=workspace_init,
        ep_size=ep_size,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_run_cutlass_moe_fp8(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    workspace_init,
):
    """Check that run_cutlass_moe_fp8's output is workspace-independent.

    Runs the kernel twice — once with a randomized workspace and once with a
    zeroed one — and asserts identical outputs, which would catch reads of
    uninitialized workspace memory.
    """
    current_platform.seed_everything(7)
    with set_current_vllm_config(vllm_config):
        mt = MOETensors8Bit.make_moe_tensors_8bit(
            m, k, n, e, per_act_token, per_out_channel
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
        # we want to make sure there is at least one token that's generated in
        # this expert shard and at least one token that's NOT generated in this
        # expert shard
        topk_ids[0][0] = -1
        topk_ids[0][1] = 1

        workspace13_shape = (m * topk, max(2 * n, k))
        workspace2_shape = (m * topk, max(n, k))
        output_shape = (m, k)

        workspace13 = torch.empty(
            prod(workspace13_shape), device="cuda", dtype=mt.a.dtype
        )
        workspace2 = torch.empty(
            prod(workspace2_shape), device="cuda", dtype=mt.a.dtype
        )

        # Expert map for the first shard only: experts [0, e//ep_size) are
        # local, everything else is -1.
        num_local_experts = e // ep_size
        start, end = 0, num_local_experts
        expert_map = [-1] * e
        expert_map[start:end] = list(range(num_local_experts))
        expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

        ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
        ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
        c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
        c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)

        activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
        a1q, a1q_scale = moe_kernel_quantize_input(
            mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
        )
        global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
        # Bind everything except the output buffer so the kernel can be
        # invoked twice with different workspaces below.
        func = lambda output: run_cutlass_moe_fp8(
            output,
            a1q,
            mt.w1_q,
            mt.w2_q,
            topk_ids,
            activation,
            global_num_experts,
            expert_map,
            mt.w1_scale,
            mt.w2_scale,
            a1q_scale,
            None,
            ab_strides1,
            ab_strides2,
            c_strides1,
            c_strides2,
            workspace13,
            workspace2,
            None,
            mt.a.dtype,
            per_act_token,
            per_out_channel,
            False,
            topk_weights,
        )

        workspace13.random_()
        output_random_workspace = torch.empty(
            output_shape, device="cuda", dtype=mt.a.dtype
        )
        func(output_random_workspace)

        workspace13.fill_(0)
        output_zero_workspace = torch.zeros(
            output_shape, device="cuda", dtype=mt.a.dtype
        )
        func(output_zero_workspace)

        torch.testing.assert_close(
            output_random_workspace, output_zero_workspace, atol=5e-3, rtol=1e-3
        )
|
||||
565
tests/kernels/moe/test_deepep_deepgemm_moe.py
Normal file
565
tests/kernels/moe/test_deepep_deepgemm_moe.py
Normal file
@@ -0,0 +1,565 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test DeepEP + DeepGEMM integration
|
||||
DeepGEMM are gemm kernels specialized for the
|
||||
fp8 block-quantized case.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
from .utils import make_test_weights
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
|
||||
if has_deep_gemm():
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
|
||||
# Skip markers for environments lacking the optional DeepEP / DeepGEMM kernels.
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep(),
    reason="Requires deep_ep kernels",
)

requires_deep_gemm = pytest.mark.skipif(
    not is_deep_gemm_supported(),
    reason="Requires deep_gemm kernels",
)

# ParamSpec for typed wrappers; presumably consumed by helpers outside this
# view (no use visible here).
P = ParamSpec("P")
||||
|
||||
|
||||
@contextmanager
def with_dp_metadata(M: int, world_size: int):
    """Enter a forward context reporting ``M`` tokens on each of ``world_size``
    DP ranks, with expert parallelism enabled.

    The MoE forward calls in these tests run inside this context so DP
    metadata is available to them.
    """
    num_tokens_across_dp = torch.tensor([M] * world_size, device="cpu", dtype=torch.int)

    vllm_config = VllmConfig()
    vllm_config.parallel_config.data_parallel_size = world_size
    vllm_config.parallel_config.enable_expert_parallel = True

    with set_forward_context(
        None,
        vllm_config,
        num_tokens=M,
        num_tokens_across_dp=num_tokens_across_dp,
    ):
        yield
|
||||
|
||||
def next_power_of_2(x: int) -> int:
    """Return the smallest power of two >= ``x`` (1 for ``x <= 0``).

    Uses integer bit arithmetic instead of ``math.log2``: the float path is
    inexact for large ints (e.g. 2**53 + 1 rounds to 2**53 as a float, giving
    a result smaller than ``x``) and raised a domain error for negatives.
    For x == 0 this returns 1, matching the original behavior.
    """
    if x <= 0:
        return 1
    # (x - 1).bit_length() is ceil(log2(x)) computed exactly on ints.
    return 1 << (x - 1).bit_length()
|
||||
|
||||
|
||||
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return weights w1q, w2q, w1_scale, w2_scale
    """
    w1_bundle, w2_bundle = make_test_weights(
        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
    )
    # Each bundle is (ref, quantized, scale, extra); keep the quantized
    # weights and their block scales.
    _, w1q, w1_scale, _ = w1_bundle
    _, w2q, w2_scale, _ = w2_bundle
    return w1q, w2q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TestConfig:
    """Shape and quantization parameters for one DeepEP + DeepGEMM test case."""

    topk: int  # experts selected per token
    m: int  # tokens per rank (see TestTensors.make)
    k: int  # hidden size of the activations
    n: int  # presumably the intermediate size — confirm against make_test_weights
    num_experts: int
    per_act_token_quant: bool
    block_size: list[int]  # fp8 block-quant shape, e.g. [128, 128]
    # configs for testing low-latency kernels
    low_latency: bool
    use_fp8_dispatch: bool | None = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TestTensors:
    """Per-rank activations and routing tensors for one test run."""

    rank_tokens: torch.Tensor  # all ranks make this many tokens
    rank_token_scales: torch.Tensor | None
    topk: torch.Tensor  # selected expert ids, shape (m, topk)
    topk_weights: torch.Tensor  # router weights, shape (m, topk)
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, rank) -> "TestTensors":
        """Create random activations and routing for this rank.

        NOTE(review): ``rank`` is unused here; tensors come from the
        caller-seeded global RNG (callers seed with the rank beforehand).
        """
        dtype = torch.bfloat16
        topk, m, k = (config.topk, config.m, config.k)

        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        rank_tokens = (
            torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
        )
        # Clamp into the fp8-representable range so later quantization
        # cannot overflow.
        rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
        rank_token_scales = None

        topk_ids = torch.randint(
            low=0,
            high=config.num_experts,
            size=(m, topk),
            device=torch.cuda.current_device(),
        ).to(dtype=torch.int64)

        topk_weights = torch.randn(
            topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device()
        )

        return TestTensors(
            rank_tokens=rank_tokens,
            rank_token_scales=rank_token_scales,
            topk=topk_ids,
            topk_weights=topk_weights,
            config=config,
        )
|
||||
|
||||
|
||||
def make_ll_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    max_tokens_per_rank: int,
    dp_size: int,
    hidden_size: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """Assemble a low-latency DeepEP prepare/finalize with batched DeepGEMM experts."""
    assert test_config.low_latency
    assert test_config.use_fp8_dispatch is not None

    ll_args = DeepEPLLArgs(
        max_tokens_per_rank=max_tokens_per_rank,
        hidden_size=hidden_size,
        num_experts=test_config.num_experts,
        use_fp8_dispatch=test_config.use_fp8_dispatch,
    )
    prepare_finalize: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=None,
        deepep_ll_args=ll_args,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    experts = BatchedDeepGemmExperts(
        max_num_tokens=max_tokens_per_rank,
        num_dispatchers=pgi.world_size // dp_size,
        quant_config=quant_config,
    )
    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize, fused_experts=experts
    )
|
||||
|
||||
|
||||
def make_ht_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """Assemble a high-throughput DeepEP prepare/finalize with DeepGEMM experts."""
    assert not test_config.low_latency
    assert test_config.use_fp8_dispatch is None

    prepare_finalize: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
        deepep_ll_args=None,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=DeepGemmExperts(quant_config),
    )
|
||||
|
||||
|
||||
def make_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    test_tensors: TestTensors,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """Build the low-latency or high-throughput modular kernel per config."""
    q_dtype = torch.float8_e4m3fn
    test_config = test_tensors.config

    if not test_config.low_latency:
        # High-throughput path.
        return make_ht_modular_kernel(
            pg=pg,
            pgi=pgi,
            dp_size=dp_size,
            num_local_experts=num_local_experts,
            q_dtype=q_dtype,
            test_config=test_config,
            quant_config=quant_config,
        )

    # Low-latency path: buffers sized to the next power of two of the token
    # count, with a floor of 64.
    num_rank_tokens = test_tensors.rank_tokens.size(0)
    max_tokens_per_rank = max(64, next_power_of_2(num_rank_tokens))
    return make_ll_modular_kernel(
        pg=pg,
        pgi=pgi,
        max_tokens_per_rank=max_tokens_per_rank,
        dp_size=dp_size,
        hidden_size=test_tensors.rank_tokens.size(-1),
        q_dtype=q_dtype,
        test_config=test_config,
        quant_config=quant_config,
    )
|
||||
|
||||
|
||||
def deepep_deepgemm_moe_impl(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
) -> torch.Tensor:
    """Run the DeepEP + DeepGEMM modular kernel for this rank's expert shard.

    ``w1``/``w2`` hold only the local experts; the expert map translates
    global expert ids to local slots (-1 for experts owned by other ranks).
    """
    test_config = test_tensors.config
    num_experts = test_config.num_experts
    num_local_experts = w1.size(0)

    def build_expert_map():
        # Global expert id -> local expert index for this rank; -1 marks
        # non-local experts.
        num_local_experts = w1.size(0)
        expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
        s = pgi.rank * num_local_experts
        e = s + num_local_experts
        expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
        return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        # Low-Latency kernels can't dispatch scales.
        a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
        block_shape=test_config.block_size,
    )

    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        num_local_experts=num_local_experts,
        test_tensors=test_tensors,
        quant_config=quant_config,
    )

    # Run the forward under DP metadata so the kernel sees the per-rank
    # token counts.
    with with_dp_metadata(
        M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
    ):
        out = mk.forward(
            hidden_states=test_tensors.rank_tokens,
            w1=w1,
            w2=w2,
            topk_weights=test_tensors.topk_weights,
            topk_ids=test_tensors.topk,
            inplace=False,
            activation="silu",
            global_num_experts=num_experts,
            expert_map=build_expert_map(),
            apply_router_weight_on_input=False,
        )
    return out
|
||||
|
||||
|
||||
def triton_impl(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor,
    block_shape: list[int],
):
    """Reference Triton fused-experts path with DeepGEMM explicitly disabled."""
    cfg = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        block_shape=block_shape,
    )

    # allow_deep_gemm must be False so we don't end up comparing the same
    # implementation against itself.
    return fused_experts(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=cfg,
        allow_deep_gemm=False,
    )
|
||||
|
||||
|
||||
def _test_deepep_deepgemm_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
):
    """Per-process body: compare DeepEP+DeepGEMM against the Triton reference.

    Launched once per rank by parallel_launch. Each rank keeps only its own
    expert shard for the DeepEP path but uses the full weights for the
    reference computation.
    """
    device = torch.device(f"cuda:{pgi.local_rank}")
    init_workspace_manager(device)

    current_platform.seed_everything(pgi.rank)

    w1 = w1.to(device=torch.cuda.current_device())
    w2 = w2.to(device=torch.cuda.current_device())
    w1_scale = w1_scale.to(device=torch.cuda.current_device())
    w2_scale = w2_scale.to(device=torch.cuda.current_device())

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, pgi.rank)
    # Recover the block-quant shape from the weight/scale tensor shapes.
    block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]

    with set_current_vllm_config(VllmConfig()):
        # Reference
        triton_moe = triton_impl(
            a=test_tensors.rank_tokens,
            topk_ids=test_tensors.topk,
            topk_weights=test_tensors.topk_weights,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=test_tensors.rank_token_scales,
            block_shape=block_shape,
        )

        # Slice experts for this rank.
        num_local_experts = config.num_experts // pgi.world_size
        e_start = num_local_experts * pgi.rank
        e_end = e_start + num_local_experts
        w1_ep = w1[e_start:e_end]
        w2_ep = w2[e_start:e_end]
        w1_scale_ep = w1_scale[e_start:e_end]
        w2_scale_ep = w2_scale[e_start:e_end]

        deepep_moe = deepep_deepgemm_moe_impl(
            pg,
            pgi,
            dp_size,
            test_tensors,
            w1_ep,
            w2_ep,
            w1_scale_ep,
            w2_scale_ep,
        )

        torch.testing.assert_close(
            triton_moe,
            deepep_moe,
            atol=6e-2,
            rtol=6e-2,
        )
|
||||
|
||||
|
||||
# (m, n, k) problem sizes for the high-throughput tests.
MNKs = [
    (8, 128, 128),
    (8, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (129, 128, 256),
    (129, 1024, 2048),
    (222, 1024, 2048),
]

TOPKS = [2, 6]
NUM_EXPERTS = [32]
||||
|
||||
|
||||
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
def test_ht_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
    disable_deepgemm_ue8m0,
    workspace_init,
):
    """
    Tests for High-Throughput DeepEP + DeepGemm integration.
    """

    m, n, k = mnk
    current_platform.seed_everything(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    # Use DeepGEMM's required M/K alignment as the square block-quant size.
    block_m = get_mk_alignment_for_contiguous_layout()[0]
    block_size = [block_m, block_m]

    world_size, dp_size = world_dp_size
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=False,
        use_fp8_dispatch=None,
    )

    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
        num_experts, n, k, block_size
    )

    # Spawn one process per GPU; each runs _test_deepep_deepgemm_moe.
    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
    )
|
||||
|
||||
|
||||
# (m, n, k) problem sizes for the low-latency tests. NOTE(review): this
# rebinds MNKs, so the shapes apply only to tests defined after this point.
MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
# Fix tests for USE_FP8_DISPATCH=True
USE_FP8_DISPATCH = [False]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    use_fp8_dispatch: bool,
    block_size: list[int],
    world_dp_size: tuple[int, int],
    disable_deepgemm_ue8m0,
    workspace_init,
):
    """
    Tests for Low-Latency DeepEP + DeepGemm integration.
    """
    # These tests require the non-e8m0 DeepGEMM path (the ue8m0 fixture above
    # should have disabled it).
    assert not is_deep_gemm_e8m0_used()

    m, n, k = mnk
    current_platform.seed_everything(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    world_size, dp_size = world_dp_size
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=True,
        use_fp8_dispatch=use_fp8_dispatch,
    )

    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
        num_experts, n, k, block_size
    )

    # Spawn one process per GPU; each runs _test_deepep_deepgemm_moe.
    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
    )
|
||||
528
tests/kernels/moe/test_deepep_moe.py
Normal file
528
tests/kernels/moe/test_deepep_moe.py
Normal file
@@ -0,0 +1,528 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test deepep dispatch-combine logic
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
|
||||
import pytest
|
||||
import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_ep
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
|
||||
# Marker applied to tests below: skip when DeepEP kernels are not installed.
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep(),
    reason="Requires deep_ep kernels",
)

# Per-rank token budget for low-latency DeepEP; also the chunk size used
# when iterating tokens in deep_ep_moe_impl.
MAX_TOKENS_PER_RANK = 64
|
||||
|
||||
|
||||
def make_weights(
    e, n, k, dtype
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return weights w1, w2, w1_scale, w2_scale.

    For fp16/bf16 the weights are unquantized and both scales are None.
    For fp8 the weights are quantized per out-channel with fp32 scales.
    """
    if dtype in (torch.float16, torch.bfloat16):
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
        return w1, w2, None, None

    # per-out-channel weight quantization
    assert dtype == torch.float8_e4m3fn
    # NOTE(review): torch.empty leaves these tensors uninitialized, so the
    # quantized weights hold arbitrary values — confirm intended (both the
    # reference and DeepEP paths do consume the same quantized tensors).
    w1 = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float16)
    w2 = torch.empty((e, k, n), device="cuda", dtype=torch.float16)

    w1_q = torch.empty_like(w1, dtype=dtype)
    w2_q = torch.empty_like(w2, dtype=dtype)
    # One scale per output channel: 2N rows for w1, K rows for w2.
    w1_scale = torch.empty((e, 2 * n, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.empty((e, k, 1), device="cuda", dtype=torch.float32)

    for expert in range(e):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
            w1[expert], use_per_token_if_dynamic=True
        )
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
            w2[expert], use_per_token_if_dynamic=True
        )

    return w1_q, w2_q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TestConfig:
    """Problem description for a single DeepEP MoE test case."""

    dtype: torch.dtype  # weight dtype (bf16/fp16 unquantized, or fp8)
    topk: int  # experts selected per token
    m: int  # tokens per rank
    k: int  # hidden size
    n: int  # intermediate size
    num_experts: int  # global expert count
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TestTensors:
    """Per-rank input tensors for a DeepEP MoE test."""

    rank_tokens: torch.Tensor  # all ranks make this many tokens
    rank_token_scales: torch.Tensor | None
    topk: torch.Tensor
    topk_weights: torch.Tensor
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors":
        # TODO (varun) - check that float16 works ?
        assert config.dtype in (torch.bfloat16, torch.float8_e4m3fn)

        # Tokens are generated in bf16 even for fp8 configs; quantization
        # (if any) happens inside the kernels under test.
        token_dtype = (
            torch.bfloat16 if config.dtype == torch.float8_e4m3fn else config.dtype
        )
        tokens = (
            torch.randn((config.m, config.k), device="cuda", dtype=token_dtype) / 10
        )

        topk_ids = torch.randint(
            low=0, high=config.num_experts, size=(config.m, config.topk), device="cuda"
        ).to(dtype=torch.int64)
        weights = torch.randn(topk_ids.shape, dtype=torch.float32, device="cuda")

        # NOTE(review): low_latency_mode is currently unused here.
        return TestTensors(
            rank_tokens=tokens,
            rank_token_scales=None,
            topk=topk_ids,
            topk_weights=weights,
            config=config,
        )
|
||||
|
||||
|
||||
def make_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    low_latency_mode: bool,
    hidden_size: int,
    dp_size: int,
    num_experts: int,
    num_local_experts: int,
    q_dtype: torch.dtype | None,
    use_fp8_dispatch: bool,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """Build a FusedMoEModularKernel backed by a DeepEP all2all (HT or LL)."""
    # Exactly one of the two arg bundles is populated, selecting the
    # high-throughput or low-latency DeepEP variant.
    ht_args: DeepEPHTArgs | None = None
    ll_args: DeepEPLLArgs | None = None

    if low_latency_mode:
        ll_args = DeepEPLLArgs(
            max_tokens_per_rank=MAX_TOKENS_PER_RANK,
            hidden_size=hidden_size,
            num_experts=num_experts,
            use_fp8_dispatch=use_fp8_dispatch,
        )
    else:
        assert not use_fp8_dispatch, (
            "FP8 Dispatch is valid only for low-latency kernels"
        )
        ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)

    a2a: DeepEPHTPrepareAndFinalize | DeepEPLLPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        q_dtype=q_dtype,
        block_shape=None,
        deepep_ht_args=ht_args,
        deepep_ll_args=ll_args,
    )

    num_dispatchers = pgi.world_size // dp_size

    if low_latency_mode:
        assert not quant_config.per_act_token_quant, "not supported in ll mode"
        experts = BatchedTritonExperts(
            max_num_tokens=MAX_TOKENS_PER_RANK,
            num_dispatchers=num_dispatchers,
            quant_config=quant_config,
        )
    else:
        experts = TritonExperts(quant_config=quant_config)

    return FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=experts)
|
||||
|
||||
|
||||
def deep_ep_moe_impl(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    low_latency_mode: bool,
    dp_size: int,
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
    num_experts: int,
    use_fp8_dispatch: bool,
    per_act_token_quant: bool,
) -> torch.Tensor:
    """
    Run the DeepEP modular-kernel MoE over this rank's tokens and return the
    combined hidden states.

    Tokens are processed in chunks of MAX_TOKENS_PER_RANK in low-latency mode
    (the LL kernels have a fixed per-rank token budget) and in one chunk
    otherwise.
    """
    num_local_experts = w1.size(0)

    def build_expert_map():
        # Map global expert ids to local ids: experts owned by this rank get
        # [0, num_local_experts); all others are -1.
        expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
        s = pgi.rank * num_local_experts
        e = s + num_local_experts
        expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
        return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)

    hidden_size = test_tensors.rank_tokens.size(1)
    is_quantized = w1.dtype == torch.float8_e4m3fn
    q_dtype = torch.float8_e4m3fn if is_quantized else None

    out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
    total_num_tokens = test_tensors.rank_tokens.size(0)

    def process_chunk(chunk_start, chunk_end, skip_result_store=False):
        # Slice this chunk's inputs.
        rank_tokens_chunk = test_tensors.rank_tokens[chunk_start:chunk_end]
        topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end]
        topk_chunk = test_tensors.topk[chunk_start:chunk_end]
        rank_token_scales_chunk = test_tensors.rank_token_scales
        if (
            rank_token_scales_chunk is not None
            and rank_token_scales_chunk.size(0) == total_num_tokens
        ):
            # per act token
            rank_token_scales_chunk = rank_token_scales_chunk[chunk_start:chunk_end]

        quant_config = FusedMoEQuantConfig.make(
            q_dtype,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token_quant,
            a1_scale=rank_token_scales_chunk,
        )

        # Make modular kernel
        mk: FusedMoEModularKernel = make_modular_kernel(
            pg,
            pgi,
            low_latency_mode,
            hidden_size,
            dp_size,
            num_experts,
            num_local_experts,
            q_dtype,
            use_fp8_dispatch,
            quant_config,
        )

        out = mk.forward(
            hidden_states=rank_tokens_chunk,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights_chunk,
            topk_ids=topk_chunk,
            inplace=False,
            activation="silu",
            global_num_experts=num_experts,
            expert_map=build_expert_map(),
            apply_router_weight_on_input=False,
        )

        if not skip_result_store:
            out_hidden_states[chunk_start:chunk_end, :].copy_(out, non_blocking=True)

    max_num_tokens_per_dp = (
        MAX_TOKENS_PER_RANK if low_latency_mode else total_num_tokens
    )

    for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp):
        # clamp start and end
        chunk_start = min(chunk_start_, total_num_tokens - 1)
        chunk_end = min(chunk_start_ + max_num_tokens_per_dp, total_num_tokens)

        # NOTE(review): skip_result_store can never be True here, since
        # range() stops below total_num_tokens; presumably kept for parity
        # with DP setups where a rank keeps dispatching past its own token
        # count — confirm.
        process_chunk(
            chunk_start, chunk_end, skip_result_store=chunk_start_ >= total_num_tokens
        )

    return out_hidden_states
|
||||
|
||||
|
||||
def torch_moe_impl(
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
    using_fp8_dispatch: bool,
    per_act_token_quant: bool,
):
    """Dense per-token reference MoE used to validate the DeepEP path."""
    a = test_tensors.rank_tokens
    topk_ids = test_tensors.topk
    topk_weights = test_tensors.topk_weights

    if using_fp8_dispatch:
        # The DeepEP implementation is requested to dispatch using FP8.
        # For numerical stability for testing, emulate the fp8 dispatch by
        # blockwise quant and de-quant.
        assert not per_act_token_quant
        aq, aq_scale = per_token_group_quant_fp8(a, 128, use_ue8m0=False)
        dequant = aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)
        a = dequant.view(a.shape).to(a.dtype)

    is_quantized = w1.dtype == torch.float8_e4m3fn
    a_dtype = a.dtype
    if is_quantized:
        # De-quantize weights and run the reference in fp32.
        w1 = w1.to(dtype=torch.float32) * w1_scale
        w2 = w2.to(dtype=torch.float32) * w2_scale
        a = a.to(dtype=torch.float32)

    num_tokens = a.size(0)
    topk = topk_ids.size(1)
    out = torch.zeros_like(a)
    act = SiluAndMul()

    # Token-by-token, expert-by-expert accumulation.
    for i in range(num_tokens):
        for j in range(topk):
            e = topk_ids[i][j]
            e_w = topk_weights[i][j]
            hidden = act(a[i] @ w1[e].transpose(0, 1))
            out[i] += (hidden @ w2[e].transpose(0, 1)) * e_w

    return out.to(dtype=a_dtype) if is_quantized else out
|
||||
|
||||
|
||||
def _deep_ep_moe(
    pgi: ProcessGroupInfo,
    low_latency_mode: bool,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
    use_fp8_dispatch: bool,
    per_act_token_quant: bool,
):
    """Per-rank worker: compare DeepEP MoE output against the torch reference."""
    device = torch.device(f"cuda:{pgi.local_rank}")
    init_workspace_manager(device)

    if not low_latency_mode:
        assert not use_fp8_dispatch, (
            "FP8 dispatch interface is available only in low-latency mode"
        )

    # Move shared weights (and scales for the fp8 path) onto this rank's GPU.
    is_quantized = w1.dtype == torch.float8_e4m3fn
    cur_device = torch.cuda.current_device()
    w1 = w1.to(device=cur_device)
    w2 = w2.to(device=cur_device)
    if is_quantized:
        w1_scale = w1_scale.to(device=cur_device)  # type: ignore
        w2_scale = w2_scale.to(device=cur_device)  # type: ignore

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, low_latency_mode)

    with set_current_vllm_config(VllmConfig()):
        # Reference result on the full (un-sharded) expert set.
        torch_combined = torch_moe_impl(
            test_tensors,
            w1,
            w2,
            w1_scale,
            w2_scale,
            use_fp8_dispatch,
            per_act_token_quant,
        )

        # Splice out this rank's experts.
        num_local_experts = config.num_experts // pgi.world_size
        e_start = num_local_experts * pgi.rank
        e_end = e_start + num_local_experts
        w1_ep = w1[e_start:e_end]
        w2_ep = w2[e_start:e_end]

        w1_scale_ep, w2_scale_ep = None, None
        if is_quantized:
            w1_scale_ep = w1_scale[e_start:e_end]  # type: ignore
            w2_scale_ep = w2_scale[e_start:e_end]  # type: ignore

        deepep_combined = deep_ep_moe_impl(
            pg,
            pgi,
            low_latency_mode,
            dp_size,
            test_tensors,
            w1_ep,
            w2_ep,
            w1_scale_ep,
            w2_scale_ep,
            config.num_experts,
            use_fp8_dispatch,
            per_act_token_quant,
        )

        torch.testing.assert_close(
            torch_combined,
            deepep_combined,
            atol=6e-2,
            rtol=6e-2,
        )
|
||||
|
||||
|
||||
# (M, N, K) shapes for the high-throughput (non-low-latency) DeepEP test.
MNKs = [
    (1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (222, 1024, 2048),
]

# Exercise both the unquantized and fp8 per-out-channel weight paths.
DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("m,n,k", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
def test_deep_ep_moe(
    dtype: torch.dtype,
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    workspace_init,
):
    """High-throughput DeepEP MoE vs. the dense torch reference."""
    # HT mode never uses fp8 dispatch.
    low_latency_mode = False
    use_fp8_dispatch = False

    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
    w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)

    parallel_launch(
        world_size,
        _deep_ep_moe,
        low_latency_mode,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
        use_fp8_dispatch,
        per_act_token_quant,
    )
|
||||
|
||||
|
||||
# (M, N, K) shapes for the low-latency DeepEP test. K is fixed at 2560;
# unsupported hidden sizes are skipped at runtime inside the test.
MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
USE_FP8_DISPATCH = [True, False]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("m,n,k", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
def test_low_latency_deep_ep_moe(
    dtype: torch.dtype,
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
    use_fp8_dispatch: bool,
    workspace_init,
):
    """Low-latency DeepEP MoE vs. the dense torch reference."""
    low_latency_mode = True

    # The LL kernels only accept a fixed set of hidden sizes.
    if low_latency_mode and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES:
        pytest.skip(
            f"Skipping test as hidden size {k} is not in list of supported "
            f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
        )

    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
    w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)

    parallel_launch(
        world_size,
        _deep_ep_moe,
        low_latency_mode,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
        use_fp8_dispatch,
        False,  # per_act_token_quant is unsupported in LL mode
    )
|
||||
180
tests/kernels/moe/test_deepgemm.py
Normal file
180
tests/kernels/moe/test_deepgemm.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit-test DeepGEMM FP8 kernels (no DeepEP).
|
||||
Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
|
||||
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.utils.deep_gemm import (
|
||||
calc_diff,
|
||||
is_deep_gemm_supported,
|
||||
per_block_cast_to_fp8,
|
||||
)
|
||||
|
||||
# (block_n, block_k) quantization tiles expected by the DeepGEMM FP8 kernels.
BLOCK_SIZE = [128, 128]
|
||||
|
||||
|
||||
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
):
    """
    Generate (w1, w2) expert weights and their per-block scale tensors
    in FP8 block-quantized format.

    w1 shape: (E, 2N, K)
    w2 shape: (E, K, N)
    Scales hold one fp32 entry per (block_n x block_k) tile.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)

    # bf16 reference weights, clamped into the fp8-representable range.
    w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=torch.bfloat16) / 10
    w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=torch.bfloat16) / 10
    w1_bf16.clamp_(finfo.min, finfo.max)
    w2_bf16.clamp_(finfo.min, finfo.max)

    block_n, block_k = block_size
    w1_tiles = (math.ceil((2 * n) / block_n), math.ceil(k / block_k))
    w2_tiles = (math.ceil(k / block_n), math.ceil(n / block_k))

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
    w1_s = torch.empty(e, *w1_tiles, device="cuda", dtype=torch.float32)
    w2_s = torch.empty(e, *w2_tiles, device="cuda", dtype=torch.float32)

    # Quantize expert-by-expert with UE8M0 scales.
    for i in range(e):
        w1[i], w1_s[i] = per_block_cast_to_fp8(
            w1_bf16[i], block_size=block_size, use_ue8m0=True
        )
        w2[i], w2_s[i] = per_block_cast_to_fp8(
            w2_bf16[i], block_size=block_size, use_ue8m0=True
        )

    return w1, w2, w1_s, w2_s
|
||||
|
||||
|
||||
def run_single_case(m, n, k, topk, num_experts, block_size):
    """
    Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
    Triton baseline within tolerance.
    """
    # Clamp activations to [-1, 1] to keep fp8 quantization well-behaved.
    tokens_bf16 = (
        torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
        .clamp_min_(-1)
        .clamp_max_(1)
    )
    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])

    # expert weight tensors
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, block_size)

    router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        block_shape=block_size,
    )

    # triton reference
    out_triton = fused_experts(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=quant_config,
        allow_deep_gemm=False,
    )

    # DeepGemm
    out_deepgemm = fused_experts(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=quant_config,
        allow_deep_gemm=True,
    )

    diff = calc_diff(out_deepgemm, out_triton)
    # Threshold is 0.001 == 0.1% (the old message incorrectly said "1%").
    assert diff < 0.001, f"Diff exceeded 0.1%: {diff}"
|
||||
|
||||
|
||||
# Note: N <= 512 will disable the deepgemm path due to performance issues.
MNKs = [
    (1024, 768, 128),
    (2048, 768, 512),
    (512, 1024, 1024),
    (4096, 4096, 1024),
]

TOPKS = [2, 6]
NUM_EXPERTS = [32]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init):
    """DeepGEMM fused_experts must match the Triton fallback, and the
    DeepGEMM path must actually be taken (verified via a call spy)."""
    # Skip before doing any patching work.
    if topk > num_experts:
        pytest.skip(f"topk={topk} > num_experts={num_experts}")

    with monkeypatch.context() as mp:
        mp.setenv("VLLM_USE_DEEP_GEMM", "1")

        _fused_moe_mod = importlib.import_module(
            "vllm.model_executor.layers.fused_moe.fused_moe"
        )

        # Spy on deep_gemm_moe_fp8 so we can assert the DeepGEMM path ran.
        call_counter = {"cnt": 0}
        orig_fn = _fused_moe_mod.deep_gemm_moe_fp8

        def _spy_deep_gemm_moe_fp8(*args, **kwargs):
            call_counter["cnt"] += 1
            return orig_fn(*args, **kwargs)

        # Use the context-scoped monkeypatch consistently (previously the
        # outer `monkeypatch` fixture was used here) so the attribute patch
        # is undone together with the env var when the context exits.
        mp.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8)

        run_single_case(
            m=m,
            n=n,
            k=k,
            topk=topk,
            num_experts=num_experts,
            block_size=BLOCK_SIZE,
        )

        # ensure that the DeepGEMM path was indeed taken.
        assert call_counter["cnt"] == 1, (
            f"DeepGEMM path was not executed during the test. "
            f"Call counter: {call_counter['cnt']}"
        )
|
||||
287
tests/kernels/moe/test_flashinfer.py
Normal file
287
tests/kernels/moe/test_flashinfer.py
Normal file
@@ -0,0 +1,287 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8,
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
register_moe_scaling_factors,
|
||||
rotate_flashinfer_fp8_moe_weights,
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Module-level guard: skip this whole file when flashinfer cutlass MoE or
# sm90+ support is unavailable.
try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )

# NOTE(review): if the import above fails on a non-ROCm platform,
# has_flashinfer_cutlass_fused_moe is left undefined and the check below
# raises NameError instead of skipping — confirm this is intended.
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Supported for sm >= 90",
        allow_module_level=True,
    )
|
||||
|
||||
NUM_EXPERTS = [16]
TOP_KS = [1]

# (M, N, K) factors for the flashinfer fp8 MoE tests.
MNK_FACTORS = [
    (256, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Minimal vLLM config used to scope the fused-MoE calls in the tests below.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
|
||||
|
||||
def quant_fp8_per_tensor_batches(a):
    """Per-tensor FP8-quantize each batch of `a`.

    Returns (quantized, inverse_scales) stacked along the batch dimension —
    flashinfer consumes inverse scale factors.
    """
    quantized = []
    inv_scales = []

    for batch_idx in range(a.size(0)):
        batch_fp8, global_sf = input_to_float8(a[batch_idx])
        quantized.append(batch_fp8)
        inv_scales.append(1.0 / global_sf)

    return torch.stack(quantized), torch.stack(inv_scales)
|
||||
|
||||
|
||||
@dataclass
class TestData:
    # Random inputs, FP8-quantized weights/scales, and a dummy layer module
    # wired the way the flashinfer MoE helpers expect.
    hidden_states: torch.Tensor
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(
        m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
    ) -> "TestData":
        """Create random bf16 activations/weights, quantize them to FP8, and
        build a dummy layer carrying the scales and routing attributes."""
        # Gated activations need 2N rows in w13; relu2_no_mul needs only N.
        is_gated = activation != "relu2_no_mul"

        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn(
            (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
        )
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)

        # Scale to fp8
        _, a1_scale = input_to_float8(hidden_states)
        # flashinfer consumes inverse scale factors.
        a1_scale = 1.0 / a1_scale
        a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        # Bare nn.Module standing in for a FusedMoE layer.
        layer = torch.nn.Module()
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
        # Setup dummy config.
        layer.moe_parallel_config = mk.FusedMoEParallelConfig(
            tp_size=1,
            pcp_size=1,
            dp_size=1,
            ep_size=1,
            tp_rank=1,
            pcp_rank=1,
            dp_rank=1,
            ep_rank=1,
            use_ep=False,
            all2all_backend="naive",
        )

        register_moe_scaling_factors(layer)

        # flashinfer expects swapped rows for w13
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if reorder:
            rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    """Compare fused_experts against the flashinfer per-tensor-scale FP8 path."""
    if not current_platform.has_device_capability(100):
        pytest.skip("Test is only supported for sm >= 100")

    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")

    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: generic fused_experts path.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation="silu",
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # Under test: flashinfer per-tensor-scale fp8 kernel.
        flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
            layer=td.layer,
            hidden_states=td.hidden_states,
            router_logits=score,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: str,
    monkeypatch,
    workspace_init,
):
    """Compare fused_experts against flashinfer_cutlass_moe_fp8 (eager mode)."""
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, reorder=False, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        # g1/g2 alphas fold the weight and activation scales together for
        # the cutlass kernel.
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
            w2_scale=td.w2_weight_scale,
            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
            a1_scale=td.a1_scale,
            a1_gscale=td.a1_scale,
            a2_scale=td.a2_scale,
            a2_gscale=1.0 / td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: generic fused_experts path.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        td.layer.dp_size = 1

        # Wire the dummy layer so flashinfer_cutlass_moe_fp8 can pull the
        # quant config through layer.quant_method.
        def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        td.layer.quant_method = td.layer

        flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
            td.hidden_states,
            td.layer,
            topk_weights,
            topk_ids,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(
            output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
        )
|
||||
152
tests/kernels/moe/test_flashinfer_moe.py
Normal file
152
tests/kernels/moe/test_flashinfer_moe.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from tests.kernels.utils import torch_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
is_valid_flashinfer_cutlass_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
|
||||
100
|
||||
):
|
||||
pytest.skip(
|
||||
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
MNK_FACTORS = [
|
||||
(2, 1024, 1024),
|
||||
(2, 3072, 1024),
|
||||
(2, 3072, 1536),
|
||||
(64, 1024, 1536),
|
||||
(64, 3072, 1024),
|
||||
(64, 2048, 1536),
|
||||
(224, 1024, 1024),
|
||||
(224, 1024, 1536),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp4_moe_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
activation: str,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
quant_blocksize = 16
|
||||
is_gated_act = activation == "silu_and_mul"
|
||||
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None,
|
||||
per_act_token_quant=False,
|
||||
make_gate=is_gated_act,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
||||
|
||||
flashinfer_experts = FusedMoEModularKernel(
|
||||
create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
|
||||
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
|
||||
)
|
||||
|
||||
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
|
||||
|
||||
flashinfer_output = flashinfer_experts(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w2=w2_q,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=fi_activation,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
_, m_k = a_fp4.shape
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=a.dtype,
|
||||
device=a.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
w1_d = torch.empty(
|
||||
(e, (2 if is_gated_act else 1) * n, k), device="cuda", dtype=dtype
|
||||
)
|
||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
quant_config.w1_scale[idx],
|
||||
(1 / quant_config.g1_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
quant_config.w2_scale[idx],
|
||||
(1 / quant_config.g2_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
torch_output = torch_moe(
|
||||
a_in_dtype, w1_d, w2_d, score, topk, activation=activation
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_flashinfer_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half)
|
||||
350
tests/kernels/moe/test_gpt_oss_triton_kernels.py
Normal file
350
tests/kernels/moe/test_gpt_oss_triton_kernels.py
Normal file
@@ -0,0 +1,350 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
if not has_triton_kernels():
|
||||
pytest.skip(
|
||||
"triton_kernels not found, skipping all related tests",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
from vllm.model_executor.layers.utils import shuffle_weight
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
|
||||
def deshuffle(w: torch.Tensor):
    """Undo column interleaving: gather even-indexed then odd-indexed
    entries of the last dimension and concatenate them."""
    evens = w[..., 0::2]
    odds = w[..., 1::2]
    return torch.cat((evens, odds), dim=-1)
|
||||
|
||||
|
||||
def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
    """Build matched reference and triton-kernel inputs for the MoE test.

    Returns a 14-tuple: reference tensors (x, w1, w1_bias, w2, w2_bias,
    exp_data) followed by triton-formatted tensors (x_tri, w1_tri, w2_tri,
    exp_data_tri, w1_bias_tri, w2_bias_tri) and the two PrecisionConfigs
    (pc1, pc2). Requires a CUDA device. Only w_dtype == "mx4" is
    implemented; other weight dtypes hit pytest.skip("NYI").
    """
    # Synthesize gating logits from random permutations so each token has a
    # distinct, deterministic-per-seed expert ordering.
    randbits = [torch.randperm(E) for _ in range(M)]
    x_list = [
        (-1) ** i
        * ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
        for i, bits in enumerate(randbits)
    ]
    exp_data = torch.stack(x_list).to(device="cuda")  # simulated gate output (M, E)

    # Create input tensor and expert weights/biases in bfloat16.
    x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
    w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda")
    w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda")

    w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda")
    w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda")

    # Clone everything before the triton-specific transforms below so the
    # reference copies stay untouched.
    exp_data_tri = exp_data.clone()
    x_tri = x.clone()
    w1_tri = w1.clone()
    w2_tri = w2.clone()

    w1_bias_tri = w1_bias.clone()
    w2_bias_tri = w2_bias.clone()
    # The triton kernel consumes fp32 biases.
    w1_bias_tri = w1_bias_tri.to(torch.float32)
    w2_bias_tri = w2_bias_tri.to(torch.float32)

    dtype_dict = {
        "bf16": torch.bfloat16,
        "fp8_e4m3": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2,
    }

    # Round-trip through the activation dtype to simulate its quantization
    # error in the reference path.
    x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
    if w_dtype != "mx4":
        # simulate quantization support on reference impl
        w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16)
        w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16)

    # The triton MoE kernel uses transposed weight shapes for its matmuls.
    w1_tri = w1_tri.transpose(-2, -1)
    w2_tri = w2_tri.transpose(-2, -1)

    # Interleave w1 columns (and its bias) into the kernel's expected layout.
    w1_tri = shuffle_weight(w1_tri)
    w1_bias_tri = shuffle_weight(w1_bias_tri)

    # Quantize the triton-side activations/weights.
    x_tri = x.to(dtype_dict[a_dtype])
    if w_dtype != "mx4":
        pytest.skip("NYI")
    else:  # quantize to mx4
        # careful on the padding here, the activation padding need to be
        # multiple of 64, the actual engine is not implemented
        w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
        w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]

        # w2's pads mirror w1's: its input dim is w1's (gated) output dim.
        w2_bottom_pad = w1_right_pad // 2
        w2_right_pad = w1_bottom_pad

        x_pad = w1_bottom_pad

        w1_tri = F.pad(
            w1_tri,
            (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
            mode="constant",
            value=0,
        )
        w2_tri = F.pad(
            w2_tri,
            (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
            mode="constant",
            value=0,
        )

        w1_bias_tri = F.pad(
            w1_bias_tri, (0, w1_right_pad, 0, 0), mode="constant", value=0
        )
        w2_bias_tri = F.pad(
            w2_bias_tri, (0, w2_right_pad, 0, 0), mode="constant", value=0
        )

        x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)

        w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
        w_scale_layout, w_scale_layout_opts = (
            layout.make_default_matmul_mxfp4_w_scale_layout(
                mx_axis=1, num_warps=num_warps
            )
        )

        # Downcast to mxfp4 for the kernel; upcast back so the reference sees
        # the same quantization error.
        w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
        w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)

        w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
        w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)

        w1_tri = convert_layout(
            wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts
        )
        w1_scale_tri = convert_layout(
            wrap_torch_tensor(w1_scale_tri),
            w_scale_layout,
            **w_scale_layout_opts,
        )

        w2_tri = convert_layout(
            wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts
        )
        w2_scale_tri = convert_layout(
            wrap_torch_tensor(w2_scale_tri),
            w_scale_layout,
            **w_scale_layout_opts,
        )

        pc1 = PrecisionConfig(
            weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
        )
        pc2 = PrecisionConfig(
            weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
        )

    # Truncate the padded reference weights back so the rest can run properly.
    w1 = w1[..., :K, : 2 * N]
    w2 = w2[..., :N, :K]

    # Undo the interleaving on the reference copy of w1.
    w1 = deshuffle(w1)

    # Restore the reference (un-transposed) weight orientation.
    w1 = w1.transpose(-1, -2).contiguous()
    w2 = w2.transpose(-1, -2).contiguous()

    return (
        x,
        w1,
        w1_bias,
        w2,
        w2_bias,
        exp_data,
        x_tri,
        w1_tri,
        w2_tri,
        exp_data_tri,
        w1_bias_tri,
        w2_bias_tri,
        pc1,
        pc2,
    )
|
||||
|
||||
|
||||
@dataclass
class ModelConfig:
    """Model hyperparameters used to size the MoE tests.

    Defaults mirror a gpt-oss-style configuration (128 experts, 4 active
    per token, hidden size 2880).
    """

    num_hidden_layers: int = 36
    num_experts: int = 128
    experts_per_token: int = 4
    vocab_size: int = 201088
    hidden_size: int = 2880
    intermediate_size: int = 2880
    head_dim: int = 64
    num_attention_heads: int = 64
    num_key_value_heads: int = 8
    sliding_window: int = 128
    initial_context_length: int = 4096
    # RoPE / NTK scaling parameters (unused by the MoE kernels themselves).
    rope_theta: float = 150000.0
    rope_parameters_factor: float = 32.0
    rope_ntk_alpha: float = 1.0
    rope_ntk_beta: float = 32.0
|
||||
|
||||
|
||||
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
    """Clamped SwiGLU over the last dimension.

    Splits x into gate/linear halves, clamps both when `limit` is set, and
    multiplies the sigmoid-gated half by (linear + 1) — the reference model
    adds an extra bias of 1 to the linear projection.
    """
    gate, linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        gate = gate.clamp(max=limit)
        linear = linear.clamp(min=-limit, max=limit)
    gated = gate * torch.sigmoid(alpha * gate)
    return gated * (linear + 1)
|
||||
|
||||
|
||||
def oai_moe_forward(
    hidden_states: torch.Tensor,  # (M, K)
    w1: torch.Tensor,  # (E, 2N, K) — per the "beck,bk->bec" einsum below
    w1_bias: torch.Tensor,  # (E, 2N)
    w2: torch.Tensor,  # (E, K, N)
    w2_bias: torch.Tensor,  # (E, K)
    gating_output: torch.Tensor,  # (M, E)
    topk: int,
):
    """Reference (eager, unfused) gpt-oss MoE forward for one token batch.

    Routes each token to its top-k experts, runs the two expert MLPs with
    the clamped-SwiGLU activation, and returns the softmax-weighted sum of
    the expert outputs, shape (M, K).
    """
    # model.py 309:330, assuming gating and norm
    t = hidden_states
    experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
    # Softmax only over the selected experts' logits.
    expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
    expert_indices = experts.indices

    # MLP #1: gather each token's selected expert weights, then batched
    # matmul via einsum (b=token, e=selected expert, c=2N, k=K).
    mlp1_weight = w1[expert_indices, ...]
    mlp1_bias = w1_bias[expert_indices, ...]
    t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
    t = swiglu(t, limit=7)

    # MLP #2 (c=K, k=N after swiglu halves the last dim).
    mlp2_weight = w2[expert_indices, ...]
    mlp2_bias = w2_bias[expert_indices, ...]
    t = torch.einsum("beck,bek->bec", mlp2_weight, t)
    t += mlp2_bias

    # Weighted sum of experts
    t = torch.einsum("bec,be->bc", t, expert_weights)

    return t
|
||||
|
||||
|
||||
@dataclass
class Case:
    """One (activation dtype, weight dtype) test combination.

    Values are string keys ("bf16", "fp8_e4m3", "mx4", ...) interpreted by
    init_compute_data.
    """

    a_dtype: str
    w_dtype: str
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case))
        for case in [
            # Case(a_dtype="bf16", w_dtype="bf16"),
            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
            Case(a_dtype="bf16", w_dtype="mx4")
        ]
    ],
)
@pytest.mark.parametrize("num_token", [2])
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
    """Check the monolithic triton MoE kernel against the eager OAI
    reference implementation on matched inputs.

    `tp` simulates tensor parallelism by shrinking the intermediate size.
    """
    from triton_kernels.tensor_details import layout

    # Older triton_kernels builds lack the mxfp4 layout helper.
    if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
        pytest.skip("make_default_matmul_mxfp4_w_layout not available")

    M = num_token
    E = ModelConfig.num_experts
    K = ModelConfig.hidden_size
    N = ModelConfig.intermediate_size // tp
    topk = ModelConfig.experts_per_token

    (
        x,
        w1,
        w1_bias,
        w2,
        w2_bias,
        exp_data,
        x_tri,
        w1_tri,
        w2_tri,
        exp_data_tri,
        w1_bias_tri,
        w2_bias_tri,
        pc1,
        pc2,
    ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)

    quant_config = FusedMoEQuantConfig.make(
        w1_bias=w1_bias_tri,
        w2_bias=w2_bias_tri,
        w1_scale=pc1,
        w2_scale=pc2,
    )

    out_triton_monolithic = triton_kernel_moe_forward(
        hidden_states=x_tri,
        w1=w1_tri,
        w2=w2_tri,
        gating_output=exp_data_tri,
        topk=topk,
        renormalize=True,
        quant_config=quant_config,
    )
    # Drop the padding columns added by init_compute_data.
    out_triton_monolithic = out_triton_monolithic[..., :K]

    out_ref = oai_moe_forward(
        hidden_states=x,
        w1=w1,
        w1_bias=w1_bias,
        w2=w2,
        w2_bias=w2_bias,
        gating_output=exp_data,
        topk=topk,
    )
    assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005)
|
||||
|
||||
|
||||
def test_unit_shuffle():
    """Verify shuffle_weight's column interleave composes correctly with the
    triton swiglu: x @ shuffle(m) through triton's swiglu must equal
    x @ m through the local swiglu reference."""
    N = ModelConfig.intermediate_size
    K = ModelConfig.hidden_size
    m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")

    x = torch.randn(K, dtype=torch.bfloat16, device="cuda")

    m_shuffled = shuffle_weight(m)

    out_ref = x @ m
    out_ref = swiglu(out_ref, limit=1.0)

    # The triton swiglu consumes interleaved gate/linear columns, which is
    # exactly what shuffle_weight produces.
    out = x @ m_shuffled
    out = triton_kernels.swiglu.swiglu_torch(
        out,
        alpha=1.702,
        precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0),
    )

    assert_close(ref=out_ref, tri=out)
|
||||
81
tests/kernels/moe/test_grouped_topk.py
Normal file
81
tests/kernels/moe/test_grouped_topk.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the MoE grouped topk kernel
|
||||
|
||||
Run `pytest tests/kernels/moe/test_grouped_topk.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_grouped_topk,
|
||||
grouped_topk,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("num_expert_group", [8])
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_grouped_topk(
    monkeypatch: pytest.MonkeyPatch,
    n_token: int,
    n_hidden: int,
    n_expert: int,
    topk: int,
    renormalize: bool,
    num_expert_group: int,
    topk_group: int,
    scoring_func: str,
    routed_scaling_factor: float,
    dtype: torch.dtype,
):
    """Compare the fused grouped-topk kernel against the unfused baseline
    (forced via VLLM_USE_FUSED_MOE_GROUPED_TOPK=0) on identical inputs."""
    current_platform.seed_everything(0)
    hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
    gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
    e_score_correction_bias = torch.randn(
        (n_expert,), dtype=torch.float32, device="cuda"
    )

    with monkeypatch.context() as m:
        # Disable the fused path so grouped_topk runs the reference code.
        m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
        baseline_topk_weights, baseline_topk_ids = grouped_topk(
            hidden_states=hidden_states,
            gating_output=gating_output,
            topk=topk,
            renormalize=renormalize,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
        )

    test_topk_weights, test_topk_ids = fused_grouped_topk(
        hidden_states=hidden_states,
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
    )

    # Weights are only comparable when renormalized; ids must always match.
    if renormalize:
        torch.testing.assert_close(
            baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
        )
    torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0)
|
||||
350
tests/kernels/moe/test_modular_kernel_combinations.py
Normal file
350
tests/kernels/moe/test_modular_kernel_combinations.py
Normal file
@@ -0,0 +1,350 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import textwrap
|
||||
import traceback
|
||||
from itertools import product
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from .modular_kernel_tools.common import (
|
||||
Config,
|
||||
RankTensors,
|
||||
WeightTensors,
|
||||
reference_moe_impl,
|
||||
run_modular_kernel,
|
||||
)
|
||||
from .modular_kernel_tools.mk_objects import (
|
||||
MK_FUSED_EXPERT_TYPES,
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||
MK_QUANT_CONFIGS,
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
|
||||
TestMoEQuantConfig,
|
||||
expert_info,
|
||||
)
|
||||
from .modular_kernel_tools.parallel_utils import (
|
||||
ProcessGroupInfo,
|
||||
parallel_launch_with_config,
|
||||
)
|
||||
|
||||
has_any_multi_gpu_package = (
|
||||
has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe()
|
||||
)
|
||||
|
||||
meets_multi_gpu_requirements = pytest.mark.skipif(
|
||||
not has_any_multi_gpu_package,
|
||||
reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def format_result(verbose, msg, ex=None):
|
||||
if ex is not None:
|
||||
x = str(ex)
|
||||
newx = x.strip(" \n\t")[:16]
|
||||
if len(newx) < len(x):
|
||||
newx = newx + " ..."
|
||||
|
||||
prefix = "E\t"
|
||||
print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
|
||||
print(f"FAILED {msg} - {newx}\n")
|
||||
elif verbose:
|
||||
print(f"PASSED {msg}")
|
||||
else:
|
||||
print(".", end="")
|
||||
|
||||
|
||||
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    base_config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    """Per-rank body launched by parallel_launch_with_config: runs every
    (M, topk) variant of `base_config` on this rank, comparing the modular
    kernel against the reference implementation.

    Failures are collected (not raised immediately) so all combinations run;
    a single RuntimeError summarizing them is raised at the end.
    """
    # Initialize workspace manager in child process
    device = torch.device(f"cuda:{pgi.local_rank}")
    init_workspace_manager(device)

    # Per-rank seed so ranks generate distinct inputs.
    current_platform.seed_everything(pgi.rank)

    # sanity check: the chunk size must have propagated through the env.
    from vllm import envs

    if base_config.fused_moe_chunk_size is not None:
        assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # get weights to this device
    weights.to_current_device()

    Ms = base_config.Ms
    assert isinstance(Ms, list)
    TOPKs = base_config.topks
    assert isinstance(TOPKs, list)

    exceptions = []
    count = 0

    for m, topk in product(Ms, TOPKs):
        # override m and topk on a fresh copy so iterations stay independent
        config = copy.deepcopy(base_config)
        config.Ms = m
        config.topks = topk

        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count = count + 1

            # inputs for rank
            rank_tensors = RankTensors.make(config, pgi)

            # modular kernel out
            mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)

            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(config, weights, rank_tensors)

            # nvfp4 quantization is lossier, especially at large K.
            if config.quant_dtype == "nvfp4":
                atol = 1e-1 if config.K < 4096 else 2e-1
                rtol = 1e-1 if config.K < 4096 else 2e-1
            else:
                atol = 3e-2
                rtol = 3e-2

            torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
            format_result(verbose, config.describe())
        except Exception as ex:
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if len(exceptions) > 0:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}."
        )
    else:
        print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
|
||||
|
||||
|
||||
def run(config: Config, verbose: bool):
    """Launch `rank_worker` for `config` across config.world_size processes.

    The config must already be valid and implemented (asserted here);
    weights are created once in the parent and shipped to each rank.
    """
    assert config.is_valid()[0]
    assert not is_nyi_config(config)

    weights: WeightTensors = WeightTensors.make(config)

    # Env vars (e.g. VLLM_FUSED_MOE_CHUNK_SIZE) are applied in the children.
    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(
        config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
    )
|
||||
|
||||
|
||||
# Problem sizes swept by generate_valid_test_cases / the parametrized tests.
Ms = [32, 64]  # token counts
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [1024]  # intermediate sizes
TOPKs = [4, 1]  # experts activated per token
Es = [32]  # total experts
DTYPEs = [torch.bfloat16]
# None disables chunking; 16 forces the chunked fused-MoE path.
FUSED_MOE_CHUNK_SIZEs = [None, 16]
|
||||
|
||||
|
||||
def is_nyi_config(config: Config) -> bool:
    """Return True for configs that are legitimate but known to fail
    because support is not yet implemented."""
    info = expert_info(config.fused_experts_type)

    if not info.needs_matching_quant:
        return not info.supports_expert_map

    # The triton kernels expect per-act-token-quant and per-out-ch-quant to
    # be either both enabled or neither; exactly one set is unsupported.
    return (config.is_per_act_token_quant + config.is_per_out_ch_quant) == 1
|
||||
|
||||
|
||||
def generate_valid_test_cases(
    world_size: int, prepare_finalize_types
) -> list[tuple[Any, ...]]:
    """Enumerate the cross-product of module-level sweep lists with the
    given prepare-finalize types, keeping only configs that are both valid
    and implemented. Returns parametrize-ready 9-tuples."""
    cases = []
    total = 0

    for k, n, e, dtype, quant_config, combination, chunk_size in product(
        Ks,
        Ns,
        Es,
        DTYPEs,
        MK_QUANT_CONFIGS,
        # combination = (prepare_finalize_type, fused_experts_type)
        product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
        FUSED_MOE_CHUNK_SIZEs,
    ):
        total = total + 1

        config = Config(
            Ms=Ms,
            K=k,
            N=n,
            E=e,
            topks=TOPKs,
            dtype=dtype,
            quant_config=quant_config,
            prepare_finalize_type=combination[0],
            fused_experts_type=combination[1],
            fused_moe_chunk_size=chunk_size,
            world_size=world_size,
        )

        # TODO(bnell): figure out how to get verbose flag here.
        verbose = False  # pytestconfig.getoption('verbose') > 0

        valid, reason = config.is_valid()

        if not valid:
            if verbose:
                print(f"Test config {config} is not valid: {reason}")
            continue

        if is_nyi_config(config):
            if verbose:
                print(f"Test config {config} is nyi.")
            continue

        cases.append(
            (
                k,
                n,
                e,
                dtype,
                quant_config,
                combination[0],
                combination[1],
                chunk_size,
                world_size,
            )
        )

    print(f"{len(cases)} of {total} valid configs generated.")

    return cases
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    ),
)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
):
    """Run every valid multi-GPU (prepare-finalize, fused-experts) combination
    across `world_size` ranks, comparing against the reference MoE impl.

    Skips when fewer than `world_size` GPUs are visible.
    """
    if cuda_device_count_stateless() < world_size:
        # Fix: corrected "exepected" -> "expected" in the skip message.
        pytest.skip(
            f"Not enough GPUs available to run, got "
            f"{cuda_device_count_stateless()} expected "
            f"{world_size}."
        )

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
    generate_valid_test_cases(
        world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
    ),
)
def test_modular_kernel_combinations_singlegpu(
    k: int,
    n: int,
    e: int,
    dtype: torch.dtype,
    quant_config: TestMoEQuantConfig | None,
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    chunk_size: int | None,
    world_size: int,
    pytestconfig,
    workspace_init,
):
    """Run every valid single-GPU (prepare-finalize, fused-experts)
    combination against the reference MoE implementation.

    Note: float8_e4m3fn is not supported on CUDA architecture < 89,
    and those tests will be skipped on unsupported hardware."""
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=prepare_finalize_type,
        fused_experts_type=fused_experts_type,
        fused_moe_chunk_size=chunk_size,
        world_size=world_size,
    )

    # fp8 requires SM 8.9+ (Ada/Hopper); skip rather than fail on older HW.
    if (
        quant_config is not None and quant_config.quant_dtype == torch.float8_e4m3fn
    ) and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Ability to test individual PrepareAndFinalize and FusedExperts combination
|
||||
from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser
|
||||
|
||||
parser = make_config_arg_parser(
|
||||
description=(
|
||||
"Run single prepare-finalize & fused-experts combination test"
|
||||
"Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "
|
||||
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
config = make_config(args)
|
||||
|
||||
run(config, True)
|
||||
250
tests/kernels/moe/test_modular_oai_triton_moe.py
Normal file
250
tests/kernels/moe/test_modular_oai_triton_moe.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test modular OAI Triton MoE
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
if not has_triton_kernels():
|
||||
pytest.skip(
|
||||
"triton_kernels not found, skipping all related tests",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
OAITritonExperts,
|
||||
UnfusedOAITritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.utils import shuffle_weight
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MNK = [
|
||||
(1, 512, 384),
|
||||
(1, 2880, 2880),
|
||||
(2, 512, 384),
|
||||
(2, 2880, 2880),
|
||||
(16, 2880, 2880),
|
||||
]
|
||||
|
||||
|
||||
def unshuffle_weight(w: torch.Tensor):
|
||||
first = w[..., ::2]
|
||||
second = w[..., 1::2]
|
||||
return torch.concat((first, second), dim=-1)
|
||||
|
||||
|
||||
def make_weights(dtype, k, n, e):
|
||||
w1 = torch.randn((e, k, 2 * n), dtype=dtype, device="cuda")
|
||||
w1_bias = torch.randn((e, 2 * n), dtype=dtype, device="cuda")
|
||||
|
||||
w2 = torch.randn((e, n, k), dtype=dtype, device="cuda")
|
||||
w2_bias = torch.randn((e, k), dtype=dtype, device="cuda")
|
||||
|
||||
w1_tri = w1.clone()
|
||||
w2_tri = w2.clone()
|
||||
|
||||
w1_bias_tri = w1_bias.clone()
|
||||
w2_bias_tri = w2_bias.clone()
|
||||
w1_bias_tri = w1_bias_tri.to(torch.float32)
|
||||
w2_bias_tri = w2_bias_tri.to(torch.float32)
|
||||
|
||||
# shuffle weights
|
||||
w1_tri = shuffle_weight(w1_tri)
|
||||
w1_bias_tri = shuffle_weight(w1_bias_tri)
|
||||
|
||||
# quant triton_weights
|
||||
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
|
||||
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
|
||||
w1 = unshuffle_weight(w1)
|
||||
|
||||
w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
|
||||
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)
|
||||
|
||||
num_warps = 8
|
||||
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
w_scale_layout, w_scale_layout_opts = (
|
||||
layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
|
||||
)
|
||||
|
||||
w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts)
|
||||
w1_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w1_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts)
|
||||
w2_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w2_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
w1_precision_config = PrecisionConfig(
|
||||
weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
|
||||
)
|
||||
w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
|
||||
)
|
||||
|
||||
return (
|
||||
w1,
|
||||
w2,
|
||||
w1_bias,
|
||||
w2_bias,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
w1_precision_config,
|
||||
w2_precision_config,
|
||||
)
|
||||
|
||||
|
||||
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
|
||||
# Note we add an extra bias of 1 to the linear layer
|
||||
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
|
||||
if limit is not None:
|
||||
x_glu = x_glu.clamp(max=limit)
|
||||
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
|
||||
if limit is not None:
|
||||
x_linear = x_linear.clamp(min=-limit, max=limit)
|
||||
return out_glu * (x_linear + 1)
|
||||
|
||||
|
||||
def torch_moe_impl(
|
||||
hidden_states: torch.Tensor, # (M, K)
|
||||
w1: torch.Tensor, # (E, K, 2N)
|
||||
w2: torch.Tensor, # (E, N, K)
|
||||
w1_bias: torch.Tensor, # (E, 2N)
|
||||
w2_bias: torch.Tensor, # (E, K)
|
||||
topk_weights: torch.Tensor, # (M, topk)
|
||||
topk_ids: torch.Tensor, # (M, topk)
|
||||
):
|
||||
w1 = w1[topk_ids, ...]
|
||||
w1_bias = w1_bias[topk_ids, ...]
|
||||
hidden_states = torch.einsum("bekc,bk->bec", w1, hidden_states) + w1_bias
|
||||
hidden_states = swiglu(hidden_states, limit=7)
|
||||
|
||||
w2 = w2[topk_ids, ...]
|
||||
w2_bias = w2_bias[topk_ids, ...]
|
||||
hidden_states = torch.einsum("bekc,bek->bec", w2, hidden_states) + w2_bias
|
||||
|
||||
# Weighted sum of experts
|
||||
hidden_states = torch.einsum("bec,be->bc", hidden_states, topk_weights)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def oai_triton_moe_impl(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: "PrecisionConfig",
|
||||
w2_scale: "PrecisionConfig",
|
||||
w1_bias: torch.Tensor | None,
|
||||
w2_bias: torch.Tensor | None,
|
||||
num_experts: int,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
unfused: bool = False,
|
||||
) -> torch.Tensor:
|
||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
|
||||
if unfused:
|
||||
fused_experts = UnfusedOAITritonExperts(quant_config)
|
||||
else:
|
||||
fused_experts = OAITritonExperts(quant_config)
|
||||
|
||||
mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
|
||||
|
||||
return mk.forward(
|
||||
hidden_states=x,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation="swigluoai",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("m,n,k", MNK)
|
||||
@pytest.mark.parametrize("num_experts", [32, 128])
|
||||
@pytest.mark.parametrize("topk", [4])
|
||||
@pytest.mark.parametrize("unfused", [True, False])
|
||||
def test_oai_triton_moe(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
unfused: bool,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
(
|
||||
w1,
|
||||
w2,
|
||||
w1_bias,
|
||||
w2_bias,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
w1_precision_config,
|
||||
w2_precision_config,
|
||||
) = make_weights(dtype, k, n, num_experts)
|
||||
|
||||
x = torch.randn((m, k), dtype=dtype, device="cuda")
|
||||
router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
|
||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)
|
||||
|
||||
out = oai_triton_moe_impl(
|
||||
x,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
w1_precision_config,
|
||||
w2_precision_config,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
num_experts,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
unfused,
|
||||
)
|
||||
|
||||
assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
|
||||
1288
tests/kernels/moe/test_moe.py
Normal file
1288
tests/kernels/moe/test_moe.py
Normal file
File diff suppressed because it is too large
Load Diff
408
tests/kernels/moe/test_moe_align_block_size.py
Normal file
408
tests/kernels/moe/test_moe_align_block_size.py
Normal file
@@ -0,0 +1,408 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the MOE align block size function.
|
||||
|
||||
Run `pytest tests/kernels/moe/test_moe_align_block_size.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
batched_moe_align_block_size,
|
||||
moe_align_block_size,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
NUM_TOKENS = [1, 3, 256, 2256, 4096]
|
||||
NUM_EXPERTS = [32, 160, 256, 257]
|
||||
TOP_KS = [1, 2, 16, 32]
|
||||
BLOCK_SIZES = [32, 128]
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
|
||||
def _group_tokens_by_expert(
|
||||
sorted_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
valid_length: int,
|
||||
total_tokens: int,
|
||||
) -> dict:
|
||||
num_blocks = valid_length // block_size
|
||||
expert_tokens: dict[int, list[int]] = {}
|
||||
|
||||
for block_idx in range(num_blocks):
|
||||
expert_id = expert_ids[block_idx].item()
|
||||
block_start = block_idx * block_size
|
||||
block_end = min(block_start + block_size, valid_length)
|
||||
|
||||
block_tokens = sorted_ids[block_start:block_end]
|
||||
valid_tokens = block_tokens[block_tokens < total_tokens]
|
||||
|
||||
if expert_id not in expert_tokens:
|
||||
expert_tokens[expert_id] = []
|
||||
expert_tokens[expert_id].extend(valid_tokens.tolist())
|
||||
return expert_tokens
|
||||
|
||||
|
||||
def _verify_expert_level_sorting(
|
||||
actual_sorted_ids: torch.Tensor,
|
||||
golden_sorted_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
valid_length: int,
|
||||
total_tokens: int,
|
||||
):
|
||||
"""
|
||||
Verify that actual_sorted_ids follows the correct expert-level sorting.
|
||||
The kerne limplementation may or may not preserve original token order
|
||||
in topk_ids in the final sorted_ids however this does not impact quality.
|
||||
"""
|
||||
# Group tokens by expert from the golden implementation
|
||||
golden_expert_tokens = _group_tokens_by_expert(
|
||||
golden_sorted_ids, expert_ids, block_size, valid_length, total_tokens
|
||||
)
|
||||
|
||||
actual_expert_tokens = _group_tokens_by_expert(
|
||||
actual_sorted_ids, expert_ids, block_size, valid_length, total_tokens
|
||||
)
|
||||
|
||||
assert set(golden_expert_tokens.keys()) == set(actual_expert_tokens.keys()), (
|
||||
f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
|
||||
f"actual={set(actual_expert_tokens.keys())}"
|
||||
)
|
||||
|
||||
for expert_id in golden_expert_tokens:
|
||||
golden_tokens = torch.tensor(
|
||||
golden_expert_tokens[expert_id], device=actual_sorted_ids.device
|
||||
)
|
||||
actual_tokens = torch.tensor(
|
||||
actual_expert_tokens[expert_id], device=actual_sorted_ids.device
|
||||
)
|
||||
assert torch.equal(
|
||||
torch.sort(golden_tokens)[0], torch.sort(actual_tokens)[0]
|
||||
), (
|
||||
f"Expert {expert_id} token mismatch: "
|
||||
f"golden={golden_expert_tokens[expert_id]}, "
|
||||
f"actual={actual_expert_tokens[expert_id]}"
|
||||
)
|
||||
|
||||
|
||||
def torch_moe_align_block_size(
|
||||
topk_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
pad_sorted_ids: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Golden torch implementation of moe_align_block_size.
|
||||
|
||||
This function aligns the token distribution across experts to be compatible
|
||||
with block size for matrix multiplication by sorting tokens by expert and
|
||||
padding to block boundaries.
|
||||
"""
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
if pad_sorted_ids:
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
if topk_ids.numel() < num_experts:
|
||||
max_num_tokens_padded = topk_ids.numel() * block_size
|
||||
|
||||
flattened_token_indices = torch.arange(
|
||||
topk_ids.numel(), device=topk_ids.device, dtype=torch.int32
|
||||
)
|
||||
flattened_expert_ids = topk_ids.flatten()
|
||||
sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, stable=True)
|
||||
sorted_token_indices = flattened_token_indices[sort_indices]
|
||||
|
||||
expert_token_counts = torch.zeros(
|
||||
num_experts, dtype=torch.int64, device=topk_ids.device
|
||||
)
|
||||
for expert_id in range(num_experts):
|
||||
mask = sorted_expert_ids == expert_id
|
||||
expert_token_counts[expert_id] = mask.sum()
|
||||
|
||||
expert_padded_counts = torch.zeros(
|
||||
num_experts, dtype=torch.int64, device=topk_ids.device
|
||||
)
|
||||
for expert_id in range(num_experts):
|
||||
original_count = expert_token_counts[expert_id]
|
||||
if expert_map is not None and expert_map[expert_id] == -1:
|
||||
continue
|
||||
if original_count > 0:
|
||||
expert_padded_counts[expert_id] = (
|
||||
(original_count + block_size - 1) // block_size
|
||||
) * block_size
|
||||
|
||||
sorted_token_ids = torch.full(
|
||||
(max_num_tokens_padded,),
|
||||
topk_ids.numel(),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
|
||||
expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device)
|
||||
|
||||
current_pos = 0
|
||||
current_block = 0
|
||||
for expert_id in range(num_experts):
|
||||
if expert_map is not None and expert_map[expert_id] == -1:
|
||||
continue
|
||||
|
||||
expert_mask = sorted_expert_ids == expert_id
|
||||
expert_tokens = sorted_token_indices[expert_mask]
|
||||
num_expert_tokens = expert_tokens.shape[0]
|
||||
|
||||
if num_expert_tokens > 0:
|
||||
sorted_token_ids[current_pos : current_pos + num_expert_tokens] = (
|
||||
expert_tokens
|
||||
)
|
||||
|
||||
expert_blocks_needed = expert_padded_counts[expert_id] // block_size
|
||||
|
||||
expert_id_new = expert_id
|
||||
if expert_map is not None:
|
||||
expert_id_new = expert_map[expert_id]
|
||||
expert_ids[current_block : current_block + expert_blocks_needed] = (
|
||||
expert_id_new
|
||||
)
|
||||
|
||||
current_pos += expert_padded_counts[expert_id]
|
||||
current_block += expert_blocks_needed
|
||||
|
||||
total_padded_tokens = expert_padded_counts.sum()
|
||||
num_tokens_post_pad = torch.tensor(
|
||||
[total_padded_tokens], dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
return sorted_token_ids, expert_ids, num_tokens_post_pad
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("pad_sorted_ids", [False, True])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_align_block_size(
|
||||
m: int, topk: int, num_experts: int, block_size: int, pad_sorted_ids: bool
|
||||
):
|
||||
"""Test moe_align_block_size without expert mapping"""
|
||||
topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
|
||||
for i in range(m):
|
||||
experts = torch.randperm(num_experts, device="cuda")[:topk]
|
||||
topk_ids[i] = experts
|
||||
|
||||
actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
pad_sorted_ids=pad_sorted_ids,
|
||||
)
|
||||
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
|
||||
torch_moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
pad_sorted_ids=pad_sorted_ids,
|
||||
)
|
||||
)
|
||||
|
||||
torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0)
|
||||
torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0)
|
||||
|
||||
# For sorted_token_ids, verify block-level correctness rather than exact
|
||||
# order Tokens within each expert's blocks can be in any order, but expert
|
||||
# regions must be correct
|
||||
_verify_expert_level_sorting(
|
||||
actual_sorted_ids,
|
||||
golden_sorted_ids,
|
||||
actual_expert_ids,
|
||||
block_size,
|
||||
actual_num_tokens.item(),
|
||||
m * topk,
|
||||
)
|
||||
|
||||
total_tokens = m * topk
|
||||
assert actual_num_tokens.item() % block_size == 0, (
|
||||
"num_tokens_post_pad should be divisible by block_size"
|
||||
)
|
||||
assert actual_num_tokens.item() >= total_tokens, (
|
||||
"num_tokens_post_pad should be at least total_tokens"
|
||||
)
|
||||
valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens]
|
||||
assert len(valid_tokens) == total_tokens, (
|
||||
f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}"
|
||||
)
|
||||
assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), (
|
||||
"expert_ids should contain valid expert indices"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [16, 32, 2048])
|
||||
@pytest.mark.parametrize("topk", [2, 4])
|
||||
@pytest.mark.parametrize("num_experts", [8, 64])
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_align_block_size_with_expert_map(
|
||||
m: int, topk: int, num_experts: int, block_size: int
|
||||
):
|
||||
"""Test moe_align_block_size with expert mapping (EP scenario)"""
|
||||
topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
|
||||
for i in range(m):
|
||||
experts = torch.randperm(num_experts, device="cuda")[:topk]
|
||||
topk_ids[i] = experts
|
||||
|
||||
expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
|
||||
local_experts = list(range(0, num_experts, 2))
|
||||
for i, expert_id in enumerate(local_experts):
|
||||
expert_map[expert_id] = i
|
||||
|
||||
actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
ignore_invalid_experts=True,
|
||||
)
|
||||
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
|
||||
torch_moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
)
|
||||
|
||||
torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0)
|
||||
torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0)
|
||||
_verify_expert_level_sorting(
|
||||
actual_sorted_ids,
|
||||
golden_sorted_ids,
|
||||
actual_expert_ids,
|
||||
block_size,
|
||||
actual_num_tokens.item(),
|
||||
m * topk,
|
||||
)
|
||||
|
||||
|
||||
def test_moe_align_block_size_deterministic():
|
||||
m, topk, num_experts, block_size = 128, 2, 32, 64
|
||||
|
||||
torch.manual_seed(42)
|
||||
topk_ids = torch.randint(
|
||||
0, num_experts, (m, topk), device="cuda", dtype=torch.int32
|
||||
)
|
||||
|
||||
# expect the results to be reproducible
|
||||
results = []
|
||||
for _ in range(5):
|
||||
sorted_ids, expert_ids, num_tokens = moe_align_block_size(
|
||||
topk_ids=topk_ids, block_size=block_size, num_experts=num_experts
|
||||
)
|
||||
results.append((sorted_ids.clone(), expert_ids.clone(), num_tokens.clone()))
|
||||
|
||||
for i in range(1, len(results)):
|
||||
assert torch.equal(results[0][0], results[i][0]), (
|
||||
"sorted_ids should be deterministic"
|
||||
)
|
||||
assert torch.equal(results[0][1], results[i][1]), (
|
||||
"expert_ids should be deterministic"
|
||||
)
|
||||
assert torch.equal(results[0][2], results[i][2]), (
|
||||
"num_tokens should be deterministic"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_tokens_per_batch", [13, 16, 512])
|
||||
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64])
|
||||
@pytest.mark.parametrize("block_size", [8, 16, 32, 64])
|
||||
@pytest.mark.parametrize("simulate_empty_batches", [False, True])
|
||||
def test_batched_moe_align_block_size(
|
||||
max_tokens_per_batch: int,
|
||||
num_experts: int,
|
||||
block_size: int,
|
||||
simulate_empty_batches: bool,
|
||||
):
|
||||
def ref_outputs(
|
||||
expert_num_tokens: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
E = expert_num_tokens.size(0)
|
||||
|
||||
# Round up so each batch can be split to blocks evenly.
|
||||
Msum = round_up(max_tokens_per_batch, block_size) * E
|
||||
ref_sorted_ids = torch.empty((Msum,), dtype=torch.int32)
|
||||
ref_expert_ids = torch.empty((Msum // block_size,), dtype=torch.int32)
|
||||
ref_num_tokens_post_pad = torch.empty((1,), dtype=torch.int32)
|
||||
|
||||
# Intialize
|
||||
sentinel = E * max_tokens_per_batch
|
||||
ref_sorted_ids.fill_(sentinel)
|
||||
ref_expert_ids.fill_(-1)
|
||||
|
||||
# Fill ref_sorted_ids
|
||||
i = 0
|
||||
for expert_id, expert_nt in enumerate(expert_num_tokens):
|
||||
token_offset = expert_id * max_tokens_per_batch
|
||||
for j in range(expert_nt):
|
||||
ref_sorted_ids[i] = token_offset + j
|
||||
i += 1
|
||||
# round up i to the next block_size
|
||||
i = round_up(i, block_size)
|
||||
|
||||
ref_num_tokens_post_pad[0] = i
|
||||
|
||||
# Fill expert_ids
|
||||
nt_ceil_sum = 0
|
||||
for expert_id, expert_nt in enumerate(expert_num_tokens):
|
||||
expert_ids_offset = nt_ceil_sum // block_size
|
||||
ceil_expert_nt = round_up(int(expert_nt.item()), block_size)
|
||||
num_blocks = ceil_expert_nt // block_size
|
||||
for x in range(num_blocks):
|
||||
ref_expert_ids[expert_ids_offset + x] = expert_id
|
||||
nt_ceil_sum += ceil_expert_nt
|
||||
|
||||
return (
|
||||
ref_sorted_ids.to("cuda"),
|
||||
ref_expert_ids.to("cuda"),
|
||||
ref_num_tokens_post_pad.to("cuda"),
|
||||
)
|
||||
|
||||
# Compute expert_num_tokens
|
||||
expert_num_tokens = torch.randint(
|
||||
low=0,
|
||||
high=max_tokens_per_batch,
|
||||
size=(num_experts,),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
if simulate_empty_batches:
|
||||
# mark half the batches to have 0 tokens
|
||||
zero_batches = torch.randperm(num_experts)[: num_experts // 2]
|
||||
expert_num_tokens[zero_batches] = 0
|
||||
|
||||
# ref outputs
|
||||
ref_sorted_ids, ref_expert_ids, ref_num_tokens_post_pad = ref_outputs(
|
||||
expert_num_tokens
|
||||
)
|
||||
|
||||
# outputs
|
||||
sorted_ids, expert_ids, num_tokens_post_pad = batched_moe_align_block_size(
|
||||
max_tokens_per_batch, block_size, expert_num_tokens.to("cuda")
|
||||
)
|
||||
|
||||
assert ref_sorted_ids.size() == sorted_ids.size(), (
|
||||
f"{ref_sorted_ids.size()} vs {sorted_ids.size()}"
|
||||
)
|
||||
assert ref_expert_ids.size() == expert_ids.size(), (
|
||||
f"{ref_expert_ids.size()} vs {expert_ids.size()}"
|
||||
)
|
||||
assert ref_num_tokens_post_pad.size() == num_tokens_post_pad.size(), (
|
||||
f"{ref_num_tokens_post_pad.size()} vs {num_tokens_post_pad.size()}"
|
||||
)
|
||||
torch.testing.assert_close(ref_sorted_ids, sorted_ids, atol=0, rtol=0)
|
||||
torch.testing.assert_close(ref_expert_ids, expert_ids, atol=0, rtol=0)
|
||||
torch.testing.assert_close(
|
||||
ref_num_tokens_post_pad, num_tokens_post_pad, atol=0, rtol=0
|
||||
)
|
||||
311
tests/kernels/moe/test_moe_permute_unpermute.py
Normal file
311
tests/kernels/moe/test_moe_permute_unpermute.py
Normal file
@@ -0,0 +1,311 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the MOE permute/unpermute kernel
|
||||
|
||||
Run `pytest tests/kernels/test_moe_permute_unpermute.py`.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
moe_permute,
|
||||
moe_permute_unpermute_supported,
|
||||
moe_unpermute,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [16, 64, 256]
|
||||
TOP_KS = [2, 6, 8]
|
||||
EP_SIZE = [1, 4, 16]
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"moe_permute_unpermute_supported is not defined for ROCm",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def torch_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# token_expert_indices: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
start_expert: int,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
align_block_size: int | None = None,
|
||||
fill_invalid_expert: int = -1,
|
||||
) -> list[torch.Tensor]:
|
||||
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
|
||||
if expert_map is not None:
|
||||
is_local_expert = expert_map[topk_ids] != -1
|
||||
not_local_expert = expert_map[topk_ids] == -1
|
||||
topk_ids = is_local_expert * (topk_ids - start_expert) + not_local_expert * (
|
||||
topk_ids + n_expert
|
||||
)
|
||||
token_expert_indices = torch.arange(
|
||||
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
|
||||
).reshape((n_token, topk))
|
||||
|
||||
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True)
|
||||
dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]
|
||||
|
||||
expert_first_token_offset = torch.zeros(
|
||||
n_local_expert + 1, dtype=torch.int64, device="cuda"
|
||||
)
|
||||
idx = 0
|
||||
for i in range(0, n_local_expert):
|
||||
cnt = 0
|
||||
while idx < sorted_topk_ids.numel() and sorted_topk_ids[idx] == i:
|
||||
cnt += 1
|
||||
idx += 1
|
||||
expert_first_token_offset[i + 1] = expert_first_token_offset[i] + cnt
|
||||
|
||||
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
|
||||
valid_row_idx = []
|
||||
if align_block_size is None:
|
||||
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
|
||||
permuted_row_size = permuted_hidden_states.shape[0]
|
||||
m_indices = torch.empty(
|
||||
permuted_row_size, device="cuda", dtype=torch.int32
|
||||
).fill_(fill_invalid_expert)
|
||||
for i in range(1, n_local_expert + 1):
|
||||
first_token_offset = expert_first_token_offset[i - 1]
|
||||
last_token_offset = expert_first_token_offset[i]
|
||||
m_indices[first_token_offset:last_token_offset] = i - 1
|
||||
src_row_id2dst_row_id_map = torch.arange(
|
||||
0, n_token * topk, device="cuda", dtype=torch.int32
|
||||
)[src2dst_idx].reshape((n_token, topk))
|
||||
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
|
||||
dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk
|
||||
return [
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map,
|
||||
dst_row_id2src_row_id_map,
|
||||
m_indices,
|
||||
valid_row_idx,
|
||||
]
|
||||
else:
|
||||
permuted_row_size = (
|
||||
(topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1)
|
||||
// align_block_size
|
||||
* align_block_size
|
||||
)
|
||||
permuted_idx = torch.full(
|
||||
(permuted_row_size,),
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
permuted_hidden_states = torch.empty(
|
||||
(permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype
|
||||
)
|
||||
align_src_row_id2dst_row_id = torch.empty(
|
||||
n_token * topk, device="cuda", dtype=torch.int32
|
||||
)
|
||||
align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset)
|
||||
m_indices = torch.empty(
|
||||
permuted_row_size, device="cuda", dtype=torch.int32
|
||||
).fill_(fill_invalid_expert)
|
||||
# get align_permuted_hidden_states,
|
||||
# valid row_idx and align_expert_first_token_offset
|
||||
for i in range(1, n_local_expert + 1):
|
||||
first_token_offset = expert_first_token_offset[i - 1]
|
||||
last_token_offset = expert_first_token_offset[i]
|
||||
n_token_in_expert = last_token_offset - first_token_offset
|
||||
align_expert_first_token_offset[i] = (
|
||||
align_expert_first_token_offset[i - 1]
|
||||
+ (n_token_in_expert + align_block_size - 1)
|
||||
// align_block_size
|
||||
* align_block_size
|
||||
)
|
||||
align_first_token_offset = align_expert_first_token_offset[i - 1]
|
||||
align_last_token_offset = align_expert_first_token_offset[i]
|
||||
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
|
||||
first_token_offset : first_token_offset + n_token_in_expert
|
||||
]
|
||||
# store token in current expert with align_first_token_offset
|
||||
permuted_hidden_states[
|
||||
align_first_token_offset : align_first_token_offset + n_token_in_expert,
|
||||
...,
|
||||
] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...]
|
||||
permuted_idx[
|
||||
align_first_token_offset : align_first_token_offset + n_token_in_expert
|
||||
] = dst_row_id2src_row_id_in_expert
|
||||
# set current expert m_indices
|
||||
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
|
||||
valid_row_idx += [
|
||||
i
|
||||
for i in range(
|
||||
align_first_token_offset,
|
||||
align_first_token_offset + n_token_in_expert,
|
||||
)
|
||||
]
|
||||
# get align_src_row_id2dst_row_id
|
||||
for i in range(n_token * topk):
|
||||
eid = sorted_topk_ids[i]
|
||||
if eid >= n_local_expert:
|
||||
# check token not in local expert
|
||||
align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1]
|
||||
continue
|
||||
first_token_offset = expert_first_token_offset[eid]
|
||||
align_first_token_offset = align_expert_first_token_offset[eid]
|
||||
token_offset = i - first_token_offset
|
||||
align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset
|
||||
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape(
|
||||
(n_token, topk)
|
||||
)
|
||||
return [
|
||||
permuted_hidden_states,
|
||||
align_expert_first_token_offset,
|
||||
align_src_row_id2dst_row_id,
|
||||
permuted_idx,
|
||||
m_indices,
|
||||
valid_row_idx,
|
||||
]
|
||||
|
||||
|
||||
def torch_unpermute(
|
||||
permuted_hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
src_row_id2dst_row_id_map: torch.Tensor,
|
||||
valid_row_idx: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
) -> torch.Tensor:
|
||||
# ignore invalid row
|
||||
n_hidden = permuted_hidden_states.shape[1]
|
||||
mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda")
|
||||
mask[valid_row_idx] = True
|
||||
permuted_hidden_states[~mask] = 0
|
||||
|
||||
permuted_hidden_states = permuted_hidden_states[
|
||||
src_row_id2dst_row_id_map.flatten(), ...
|
||||
]
|
||||
permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden)
|
||||
output = (
|
||||
(permuted_hidden_states * topk_weights.unsqueeze(2))
|
||||
.sum(1)
|
||||
.to(permuted_hidden_states.dtype)
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_token", [1, 33, 1024, 5000])
@pytest.mark.parametrize("n_hidden", [2048, 7168])
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(
    n_token: int,
    n_hidden: int,
    topk: int,
    n_expert: int,
    ep_size: int,
    dtype: torch.dtype,
    align_block_size: int | None,
):
    """Check the moe_permute / moe_unpermute kernels against the torch
    reference implementations (torch_permute / torch_unpermute) across
    token counts, expert counts, EP shardings and block alignment."""
    if not moe_permute_unpermute_supported():
        pytest.skip("moe_permute_unpermute is not supported on this platform.")
    fill_invalid_expert = 0
    # Random EP rank so different expert shards get exercised across runs.
    ep_rank = np.random.randint(0, ep_size)
    expert_map = None
    n_local_expert = n_expert
    if ep_size != 1:
        # With expert parallelism only a slice of the experts is local to
        # this rank; expert_map translates global -> local expert ids.
        n_local_expert, expert_map, _ = determine_expert_map(ep_size, ep_rank, n_expert)
        expert_map = expert_map.cuda()
    start_expert = n_local_expert * ep_rank
    current_platform.seed_everything(0)
    hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
    gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        hidden_states, gating_output, topk, False
    )
    # Reference ("gold") permutation computed in plain torch.
    (
        gold_permuted_hidden_states,
        gold_expert_first_token_offset,
        gold_inv_permuted_idx,
        gold_permuted_idx,
        gold_m_indices,
        valid_row_idx,
    ) = torch_permute(
        hidden_states,
        topk_ids,
        # token_expert_indices,
        topk,
        n_expert,
        n_local_expert,
        start_expert,
        expert_map=expert_map,
        align_block_size=align_block_size,
        fill_invalid_expert=fill_invalid_expert,
    )

    # Kernel under test.
    (
        permuted_hidden_states,
        _,
        expert_first_token_offset,
        inv_permuted_idx,
        m_indices,
    ) = moe_permute(
        hidden_states=hidden_states,
        a1q_scale=None,
        topk_ids=topk_ids,
        n_expert=n_expert,
        n_local_expert=n_local_expert,
        expert_map=expert_map,
        align_block_size=align_block_size,
        fill_invalid_expert=fill_invalid_expert,
    )

    # check expert_first_token_offset
    torch.testing.assert_close(
        gold_expert_first_token_offset, expert_first_token_offset, atol=0, rtol=0
    )
    # check src_row_id2dst_row_id_map
    torch.testing.assert_close(
        gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
    )
    # check mindice
    # current kernel usage assumes deepgemm requires align_block_size
    # when it's not provided then we don't compute m_indices (for cutlass)
    if align_block_size is not None:
        torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)

    # check permuted_hidden_states, only valid token
    torch.testing.assert_close(
        gold_permuted_hidden_states[valid_row_idx],
        permuted_hidden_states[valid_row_idx],
        atol=0,
        rtol=0,
    )
    # add a random tensor to simulate group gemm
    result0 = 0.5 * permuted_hidden_states + torch.randn_like(permuted_hidden_states)
    result4 = torch.empty_like(hidden_states)
    moe_unpermute(
        result4, result0, topk_weights, inv_permuted_idx, expert_first_token_offset
    )

    gold4 = torch_unpermute(
        result0,
        topk_weights,
        topk_ids,
        token_expert_indices,
        inv_permuted_idx,
        valid_row_idx,
        topk,
        n_local_expert,
    )
    # check unpermuted hidden
    torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)
|
||||
140
tests/kernels/moe/test_nvfp4_moe.py
Normal file
140
tests/kernels/moe/test_nvfp4_moe.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from tests.kernels.utils import torch_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
"Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True
|
||||
)
|
||||
|
||||
# (m, n, k) GEMM problem sizes exercised by the cutlass FP4 MoE test:
# m = token count, n = intermediate size, k = hidden size.
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (64, 1024, 1024),
    (64, 3072, 1024),
    (64, 2048, 1536),
    (224, 1024, 1024),
    (224, 1024, 1536),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(
    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init
):
    """Compare cutlass_moe_fp4 against a dequantize-then-torch_moe reference.

    NOTE(review): `workspace_init` has no parametrize mark and no fixture
    visible in this file — presumably provided by a conftest fixture;
    confirm, otherwise pytest will error with "fixture not found".
    """
    current_platform.seed_everything(7)
    with set_current_vllm_config(
        VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
    ):
        quant_blocksize = 16

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

        # NVFP4-quantized expert weights plus per-block scales and global scales.
        (_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = (
            make_test_weights(
                e,
                n,
                k,
                in_dtype=dtype,
                quant_dtype="nvfp4",
                block_shape=None,  # use quant_blocksize?
                per_out_ch_quant=False,
            )
        )

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        # Unit activation global scales: activations are not rescaled here.
        a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)

        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        quant_config = nvfp4_moe_quant_config(
            g1_alphas=(1 / w1_gs),
            g2_alphas=(1 / w2_gs),
            a1_gscale=a1_gs,
            a2_gscale=a2_gs,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
        )

        # Kernel under test.
        cutlass_output = cutlass_moe_fp4(
            a=a,
            w1_fp4=w1_q,
            w2_fp4=w2_q,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_config=quant_config,
            m=m,
            n=n,
            k=k,
            e=e,
        )

        # Reference check:
        # quantize/dequantize the activations so the reference sees the same
        # precision loss the kernel does.
        a_global_scale = (
            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
        ).to(torch.float32)
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)

        a_in_dtype = dequantize_nvfp4_to_dtype(
            a_fp4,
            a_scale_interleaved,
            a_global_scale,
            dtype=a.dtype,
            device=a.device,
            block_size=quant_blocksize,
        )

        # Dequantize all expert weights back to the working dtype.
        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)

        for idx in range(0, e):
            w1_d[idx] = dequantize_nvfp4_to_dtype(
                w1_q[idx],
                w1_blockscale[idx],
                w1_gs[idx],
                dtype=dtype,
                device=w1_q.device,
                block_size=quant_blocksize,
            )
            w2_d[idx] = dequantize_nvfp4_to_dtype(
                w2_q[idx],
                w2_blockscale[idx],
                w2_gs[idx],
                dtype=dtype,
                device=w2_q.device,
                block_size=quant_blocksize,
            )

        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

        torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Pass the (m, n, k) dimensions as separate scalars: the test function
    # takes individual ints, not a tuple, and also requires the trailing
    # `workspace_init` argument (unused by the body, so None is fine here).
    test_cutlass_fp4_moe_no_graph(2, 1024, 1024, 40, 1, torch.half, None)
|
||||
993
tests/kernels/moe/test_ocp_mx_moe.py
Normal file
993
tests/kernels/moe/test_ocp_mx_moe.py
Normal file
@@ -0,0 +1,993 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib.metadata
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
# amd-quark >= 0.9 ships the MXFP4/MXFP6 quantization schemes under test.
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
    importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")

# trtllm-gen MXFP4 kernels require a CUDA device of the SM100 family.
TRTLLM_GEN_MXFP4_AVAILABLE = (
    current_platform.is_cuda() and current_platform.is_device_capability_family(100)
)

# Hopper (SM90) BF16 MXFP4 path additionally requires flashinfer.
HOPPER_MXFP4_BF16_AVAILABLE = (
    current_platform.is_cuda()
    and current_platform.is_device_capability(90)
    and has_flashinfer()
)
|
||||
|
||||
if TRTLLM_GEN_MXFP4_AVAILABLE:
|
||||
from flashinfer import (
|
||||
fp4_quantize,
|
||||
mxfp8_quantize,
|
||||
next_positive_power_of_2,
|
||||
reorder_rows_for_gated_act_gemm,
|
||||
shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a,
|
||||
trtllm_fp4_block_scale_moe,
|
||||
)
|
||||
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
|
||||
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
|
||||
|
||||
|
||||
@dataclass
class ModelCase:
    """A quantized model checkpoint plus the tensor-parallel size to load it with."""

    # HuggingFace model id of the checkpoint under test.
    model_id: str
    # Tensor-parallel degree used when loading the model.
    tp: int
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """`LLM.apply_model` requires pickling a function."""
    # Opt in to vLLM's insecure serialization for every test in this module.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_case",
    [
        ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=2),
        ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
        ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
        ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=1),
        ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=4),
    ],
)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
    """Smoke test: load each quark MXFP4/MXFP6 checkpoint with the requested
    TP degree and check that greedy generation produces output."""
    if torch.cuda.device_count() < model_case.tp:
        pytest.skip(
            f"This test requires >={model_case.tp} gpus, got only "
            f"{torch.cuda.device_count()}"
        )

    # `cudagraph_capture_sizes=[16]` to reduce load time.
    with vllm_runner(
        model_case.model_id,
        tensor_parallel_size=model_case.tp,
        load_format="dummy",
        cudagraph_capture_sizes=[16],
    ) as llm:
        # Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
        # def check_model(model):
        #     from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
        #         QuarkLinearMethod)
        #     from vllm.model_executor.layers.quantization.quark.schemes.quark_ocp_mx import QuarkOCP_MX  # noqa: E501
        #     from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
        #         QuarkOCP_MX_MoEMethod)

        #     layer = model.model.layers[0]

        #     qkv_proj = layer.self_attn.qkv_proj

        #     assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        #     assert isinstance(qkv_proj.scheme, QuarkOCP_MX)

        #     assert isinstance(layer.mlp.experts.quant_method,
        #                       QuarkOCP_MX_MoEMethod)

        # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
        #     llm.apply_model(check_model)

        output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20)
        assert output
|
||||
|
||||
|
||||
def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: float | None = None):
|
||||
# Note we add an extra bias of 1 to the linear layer
|
||||
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
|
||||
if limit is not None:
|
||||
x_glu = x_glu.clamp(max=limit)
|
||||
x_linear = x_linear.clamp(min=-limit, max=limit)
|
||||
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
|
||||
return out_glu * (x_linear + beta)
|
||||
|
||||
|
||||
# FP4 (E2M1) code -> value table, indexed by the 4-bit encoding (sign bit high).
fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]
|
||||
|
||||
|
||||
def mxfp4_dequantize(x, scale):
    """Dequantize packed MXFP4 data: two 4-bit E2M1 codes per uint8 byte,
    with one E8M0 (power-of-two) scale shared by each group of 32 values."""
    assert x.dtype == torch.uint8
    packed = x.view(torch.uint8).to(torch.int32)
    # Unpack each byte into its low nibble (even positions) and high nibble
    # (odd positions).
    codes = torch.zeros(
        *packed.shape[:-1], packed.shape[-1] * 2, dtype=torch.int32, device=packed.device
    )
    codes[..., 0::2].copy_(packed & 0xF)
    codes[..., 1::2].copy_((packed >> 4) & 0xF)

    # Translate every 4-bit code to its FP4 value.
    values = torch.zeros(codes.shape, dtype=torch.float32, device=packed.device)
    for code, val in enumerate(fp4_lookup_table):
        values[codes == code] = val

    # E8M0 scales: the uint8 is exactly the float32 exponent field, so shift
    # it into place and reinterpret the bits as float32.
    exponents = scale.view(torch.uint8).to(torch.int32)
    scales = (exponents << 23).view(torch.float32)
    scales = scales.reshape(*packed.shape[:-1], -1)
    # Broadcast each group scale over its 32 consecutive values.
    scales = torch.stack([scales] * 32, dim=-1).reshape(*values.shape)

    return values * scales
|
||||
|
||||
|
||||
def mxfp8_dequantize(x, scale):
|
||||
assert x.dtype == torch.float8_e4m3fn
|
||||
x_float = x.to(torch.float32)
|
||||
|
||||
scale = scale.view(torch.uint8).to(torch.int32)
|
||||
scale = (scale << 23).view(torch.float32)
|
||||
scale = scale.reshape(*x.shape[:-1], -1)
|
||||
scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
|
||||
|
||||
return x_float * scale
|
||||
|
||||
|
||||
def reference_moe(
    roouting_logits,
    topk,
    num_experts,
    hidden_states,
    w13,
    bias13,
    w2,
    bias2,
    alpha,
    beta,
    limit,
    act_type,
):
    """Dense reference MoE forward pass.

    Routes each token to its top-k experts (softmax over the selected
    logits only, i.e. renormalized routing), runs the two expert MLPs with
    swiglu in between, and returns the expert-weighted sum in bfloat16.
    With act_type == "mxfp8" the intermediate activation is round-tripped
    through MXFP8 quantization to mimic the kernel's activation format.
    """
    # renormalize routing
    selected = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True)
    routing_weights = torch.nn.functional.softmax(selected.values, dim=1)
    chosen_experts = selected.indices
    acts = hidden_states.clone()
    # MLP #1: gather per-token expert weights, then batched matmul + bias.
    acts = (
        torch.einsum("beck,bk->bec", w13[chosen_experts, ...], acts)
        + bias13[chosen_experts, ...]
    )
    acts = swiglu(acts, alpha=alpha, beta=beta, limit=limit)

    if act_type == "mxfp8":
        # Emulate the kernel's mxfp8 activation quantization round-trip.
        quantized, act_scale = mxfp8_quantize(
            acts.to(torch.bfloat16), is_sf_swizzled_layout=False
        )
        acts = mxfp8_dequantize(quantized, act_scale)
    # MLP #2
    acts = (
        torch.einsum("beck,bek->bec", w2[chosen_experts, ...], acts)
        + bias2[chosen_experts, ...]
    )
    # Weighted sum of experts
    combined = torch.einsum("bec,be->bc", acts, routing_weights)
    assert combined.shape == hidden_states.shape
    return combined.to(torch.bfloat16)
|
||||
|
||||
|
||||
def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
    """Pick the CTA tile size (tokens per tile) for the trtllm-gen MoE kernel.

    Estimates how many tokens the busiest expert will see, assuming a fixed
    load-imbalance ratio over a perfectly balanced distribution, rounds that
    up to a power of two, and clamps to the kernel-supported 8..64 range.
    """
    # Number of tokens in the input tensor.
    num_tokens = x.shape[0]
    # Ratio of the busiest expert's load to a perfectly balanced one:
    # 1.0 is a perfect distribution, > 1.0 means some experts are hotter,
    # < 1.0 does not make sense.
    imbalance_factor = 1.3
    # Tokens per expert under perfect balance, inflated by the imbalance.
    tokens_per_expert = int((num_tokens * top_k) // num_experts * imbalance_factor)
    # Round up to the next power of two, then cap to the 8-64 tokens per
    # CTA tile range supported by the kernel.
    return min(max(next_positive_power_of_2(tokens_per_expert), 8), 64)
|
||||
|
||||
|
||||
def tg_mxfp4_moe(
    router_logits,
    topk,
    num_experts,
    intermediate_size,
    hidden_size,
    hidden_states,
    hidden_states_scale,
    w13_weight,
    w13_weight_scale,
    w13_bias,
    w2_weight,
    w2_weight_scale,
    w2_bias,
    act_type,
    alpha,
    beta,
    limit,
    transpose_optimized: bool = False,
) -> torch.Tensor:
    """Run the trtllm-gen MXFP4 block-scale MoE kernel on pre-quantized
    weights, performing the layout transformations the kernel expects
    (w1/w3 swap, gated-act row interleave, epilogue shuffling).

    NOTE(review): `w13_weight`, `w13_weight_scale` and `w13_bias` are
    modified in place (the w1/w3 halves are swapped via copy_) — callers
    must not rely on their original contents afterwards.
    NOTE(review): `act_type` is accepted but never read here; the
    activation format is implied by `hidden_states`/`hidden_states_scale`.
    """
    # MX scale granularity: one scale per 32 values.
    sf_block_size = 32
    # Shape sanity checks; weights are packed two FP4 values per byte,
    # hence the trailing "// 2" dims.
    assert (
        w13_weight.dim() == 3
        and w13_weight.shape[0] == num_experts
        and w13_weight.shape[1] == intermediate_size * 2
        and w13_weight.shape[2] == hidden_size // 2
    )
    assert (
        w13_weight_scale.dim() == 3
        and w13_weight_scale.shape[0] == num_experts
        and w13_weight_scale.shape[1] == intermediate_size * 2
        and w13_weight_scale.shape[2] == hidden_size // sf_block_size
    )
    assert (
        w2_weight.dim() == 3
        and w2_weight.shape[0] == num_experts
        and w2_weight.shape[1] == hidden_size
        and w2_weight.shape[2] == intermediate_size // 2
    )
    assert (
        w2_weight_scale.dim() == 3
        and w2_weight_scale.shape[1] == hidden_size
        and w2_weight_scale.shape[2] == intermediate_size // sf_block_size
    )
    assert (
        w13_bias.dim() == 2
        and w13_bias.shape[0] == num_experts
        and w13_bias.shape[1] == intermediate_size * 2
    )
    assert (
        w2_bias.dim() == 2
        and w2_bias.shape[0] == num_experts
        and w2_bias.shape[1] == hidden_size
    )

    # Swap w1 and w3 as the definition of
    # swiglu is different in the trtllm-gen
    w13_weight_scale_ = w13_weight_scale.clone()
    w13_weight_ = w13_weight.clone()
    w13_bias_ = w13_bias.clone()
    w13_weight[:, :intermediate_size, :].copy_(w13_weight_[:, intermediate_size:, :])
    w13_weight[:, intermediate_size:, :].copy_(w13_weight_[:, :intermediate_size, :])
    w13_weight_scale[:, :intermediate_size, :].copy_(
        w13_weight_scale_[:, intermediate_size:, :]
    )
    w13_weight_scale[:, intermediate_size:, :].copy_(
        w13_weight_scale_[:, :intermediate_size, :]
    )
    w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
    w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])

    # Interleave the weights and scaling factors for activation
    w13_weight_interleaved = []
    w13_weight_scale_interleaved = []
    w13_bias_interleaved = []
    for i in range(num_experts):
        w13_weight_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
        )
        w13_weight_scale_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
        )
        w13_bias_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
        )
    w13_weight = torch.stack(w13_weight_interleaved).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 2
    )
    w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 32
    )
    w13_bias = torch.stack(w13_bias_interleaved).reshape(
        num_experts, 2 * intermediate_size
    )

    # Shuffle weights and scaling factors for transposed mma output
    gemm1_weights_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_shuffled = []
    gemm2_scales_shuffled = []
    gemm1_bias_shuffled = []
    gemm2_bias_shuffled = []
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    if transpose_optimized:
        # Cached-permutation path: compute each permutation once per shape
        # and reuse it across experts.
        for i in range(num_experts):
            # w13 weight shuffling
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_shuffled.append(
                w13_weight[i]
                .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                .contiguous()
            )
            # w13 scale shuffling
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm1_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w13_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
                    .contiguous()
                )
            )
            # w13 bias shuffling
            permute_bias_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm1_bias_shuffled.append(
                w13_bias[i]
                .clone()
                .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                .contiguous()
            )
            # w2 weight shuffling
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_shuffled.append(
                w2_weight[i]
                .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                .contiguous()
            )
            # w2 scale shuffling
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm2_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w2_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
                    .contiguous()
                )
            )
            # w2 bias shuffling
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm2_bias_shuffled.append(
                w2_bias[i]
                .clone()
                .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                .contiguous()
            )

    else:
        # Straightforward path using the flashinfer shuffle helpers.
        for i in range(num_experts):
            gemm1_weights_shuffled.append(
                shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
            )
            gemm1_scales_shuffled.append(
                shuffle_matrix_sf_a(
                    w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
                )
            )

            gemm2_weights_shuffled.append(
                shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
            )
            gemm2_scales_shuffled.append(
                shuffle_matrix_sf_a(
                    w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
                )
            )
            gemm1_bias_shuffled.append(
                shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)
            )
            gemm2_bias_shuffled.append(
                shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)
            )

    # Re-stack the per-expert results into the layouts the kernel consumes.
    w13_weight = torch.stack(gemm1_weights_shuffled)
    w13_weight_scale = (
        torch.stack(gemm1_scales_shuffled)
        .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
        .view(torch.float8_e4m3fn)
    )
    w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)

    w2_weight = torch.stack(gemm2_weights_shuffled)
    w2_weight_scale = (
        torch.stack(gemm2_scales_shuffled)
        .reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
        .view(torch.float8_e4m3fn)
    )
    w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)

    tg_result = trtllm_fp4_block_scale_moe(
        routing_logits=router_logits.to(torch.bfloat16),
        routing_bias=None,
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale,
        gemm1_bias=w13_bias,
        gemm1_alpha=alpha,
        gemm1_beta=beta,
        gemm1_clamp_limit=limit,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale,
        gemm2_bias=w2_bias,
        output1_scale_scalar=None,
        output1_scale_gate_scalar=None,
        output2_scale_scalar=None,
        num_experts=num_experts,
        top_k=topk,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
        routing_method_type=1,  # renormalize
        do_finalize=True,
    )[0]
    return tg_result
|
||||
|
||||
|
||||
def check_accuracy(a, b, atol, rtol, percent):
    """Allow a mismatch percentage of 1 - percent."""
    # Non-finite values would make the elementwise tolerance check
    # meaningless, so reject them outright.
    if torch.any(torch.isnan(a)):
        raise Exception("NaN in reference output")
    if torch.any(torch.isnan(b)):
        raise Exception("NaN in actual output")
    if torch.any(torch.isinf(a)):
        raise Exception("Inf in reference output")
    if torch.any(torch.isinf(b)):
        raise Exception("Inf in actual output")
    assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

    # Elementwise |a - b| <= atol + rtol * |b|; up to (1 - percent) of the
    # elements may violate the bound before we fail.
    abs_diff = torch.abs(a - b)
    tolerance = atol + rtol * torch.abs(b)
    bad_fraction = torch.sum(abs_diff > tolerance) / a.numel()
    if bad_fraction > 1 - percent:
        raise Exception(
            f"Mismatch percentage is {bad_fraction:.4f} for rtol {rtol} "
            f"(threshold: {1 - percent:.4f})"
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"])
@pytest.mark.parametrize("transpose_optimized", [False, True])
@pytest.mark.skipif(
    not TRTLLM_GEN_MXFP4_AVAILABLE,
    reason="nvidia gpu and compute capability sm100 is required for this test",
)
def test_trtllm_gen_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: float | None,
    act_type: str,
    transpose_optimized: bool,
):
    """Compare the trtllm-gen MXFP4 fused MoE kernel (via tg_mxfp4_moe)
    against the dense reference_moe on dequantized weights."""
    seed = 42
    torch.manual_seed(seed)
    hidden_states = torch.randn(
        num_tokens, hidden_size, device="cuda:0", dtype=torch.bfloat16
    )
    w13 = torch.randn(
        num_experts,
        intermediate_size * 2,
        hidden_size,
        device="cuda:0",
        dtype=torch.bfloat16,
    )
    w2 = torch.randn(
        num_experts,
        hidden_size,
        intermediate_size,
        device="cuda:0",
        dtype=torch.bfloat16,
    )
    bias13 = torch.randn(num_experts, intermediate_size * 2, device="cuda:0") * 10
    bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
    router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda()

    # Quantize both weight matrices to MXFP4 (32-wide UE8M0 scale groups).
    w13, w13_scale = fp4_quantize(
        w13,
        torch.tensor(1.0, device="cuda:0"),
        32,
        sf_use_ue8m0=True,
        is_sf_swizzled_layout=False,
    )
    w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, intermediate_size * 2, hidden_size // 32
    )
    w2, w2_scale = fp4_quantize(
        w2,
        torch.tensor(1.0, device="cuda:0"),
        32,
        sf_use_ue8m0=True,
        is_sf_swizzled_layout=False,
    )
    w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 32
    )
    if act_type == "mxfp8":
        # Quantize activations too when the kernel runs in mxfp8 mode.
        hidden_states, hidden_states_scale = mxfp8_quantize(
            hidden_states, is_sf_swizzled_layout=False
        )
        hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1)
    else:
        hidden_states_scale = None

    # reference result
    ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    # Dequantize so the reference sees exactly what the kernel sees.
    w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
    w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
    bias13_ref = bias13
    bias2_ref = bias2
    if act_type == "mxfp8":
        hidden_states_ref = mxfp8_dequantize(hidden_states, hidden_states_scale).to(
            torch.float32
        )
    else:
        hidden_states_ref = hidden_states.to(torch.float32)
    # Process tokens in chunks of 32 to reduce memory usage
    chunk_size = 32
    num_chunks = (num_tokens + chunk_size - 1) // chunk_size
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min(start_idx + chunk_size, num_tokens)
        chunk_result = reference_moe(
            router_logits[start_idx:end_idx].to(torch.float32),
            topk,
            num_experts,
            hidden_states_ref[start_idx:end_idx],
            w13_ref,
            bias13_ref,
            w2_ref,
            bias2_ref,
            alpha,
            beta,
            limit,
            act_type,
        )
        ref_result[start_idx:end_idx].copy_(chunk_result)

    # trtllm-gen result
    # The kernel takes per-expert alpha/beta/limit vectors, not scalars.
    if alpha is not None:
        alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts,), limit, device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts,), beta, device=hidden_states.device)
    tg_result = tg_mxfp4_moe(
        router_logits,
        topk,
        num_experts,
        intermediate_size,
        hidden_size,
        hidden_states,
        hidden_states_scale,
        w13,
        w13_scale,
        bias13,
        w2,
        w2_scale,
        bias2,
        act_type,
        alpha=alpha,
        beta=beta,
        limit=limit,
        transpose_optimized=transpose_optimized,
    )
    # relatively loose check since the mxfp4 quantization is less accurate
    check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
|
||||
|
||||
|
||||
def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
|
||||
"""Interleave scales on the last dimension by groups of 4, matching
|
||||
the transformation in mxfp4.py's BF16 (Hopper) path."""
|
||||
s = scales.to(torch.uint8)
|
||||
s_shape = s.shape
|
||||
assert s_shape[-1] % 4 == 0
|
||||
s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
|
||||
# Move the 4-group dimension before the row dimension
|
||||
permuted = s.permute(0, 2, 1, 3)
|
||||
# Merge the row dim with the 4-group dim
|
||||
return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not HOPPER_MXFP4_BF16_AVAILABLE,
    reason="nvidia gpu sm90 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: float | None,
):
    """Compare the FlashInfer CUTLASS fused MoE (BF16 activations, MXFP4
    weights, Hopper path) against a float32 ``reference_moe`` run on the
    dequantized weights."""
    torch.manual_seed(42)
    device = "cuda:0"

    # Inputs
    hidden_states = torch.randn(
        num_tokens, hidden_size, device=device, dtype=torch.bfloat16
    )
    # Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
    w13_q = torch.randint(
        0,
        256,
        (num_experts, 2 * intermediate_size, hidden_size // 2),
        device=device,
        dtype=torch.uint8,
    )
    # NOTE(review): scales drawn from a narrow range (118..122), presumably
    # to keep dequantized magnitudes moderate — confirm against mxfp4.py.
    w13_scale = torch.randint(
        118,
        123,
        (num_experts, 2 * intermediate_size, hidden_size // 32),
        device=device,
        dtype=torch.uint8,
    )

    w2_q = torch.randint(
        0,
        256,
        (num_experts, hidden_size, intermediate_size // 2),
        device=device,
        dtype=torch.uint8,
    )
    w2_scale = torch.randint(
        118,
        123,
        (num_experts, hidden_size, intermediate_size // 32),
        device=device,
        dtype=torch.uint8,
    )
    # Bias contiguous [b1; b3]
    bias13 = (
        torch.randn(
            num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
        )
        * 10
    )
    bias2 = (
        torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
    )
    router_logits = torch.rand(
        num_tokens, num_experts, dtype=torch.float32, device=device
    )

    # Dequantize the random MXFP4 weights so the reference runs in full
    # precision on numerically identical weight values.
    w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
        num_experts, 2 * intermediate_size, hidden_size
    )
    w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
        num_experts, hidden_size, intermediate_size
    )
    ref = reference_moe(
        router_logits.to(torch.float32),
        topk,
        num_experts,
        hidden_states.to(torch.float32),
        w13_ref,
        bias13.to(torch.float32),
        w2_ref,
        bias2.to(torch.float32),
        alpha,
        beta,
        limit,
        "bf16",
    )

    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    # Swap halves to arrange as [w3; w1] (kernel expectation)
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    # Biases must be swapped the same way as the weights.
    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

    # Scales: swap halves, then interleave by 4 on the last dim for the kernel.
    w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
    w13_s = torch.cat([w3_s, w1_s], dim=1)
    w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
    w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)

    # Top-k routing with renormalized softmax weights.
    routing_weights = torch.nn.functional.softmax(
        router_logits, dim=1, dtype=torch.float32
    )
    token_final_scales, token_selected_experts = torch.topk(
        routing_weights, topk, dim=-1
    )
    token_final_scales = token_final_scales / token_final_scales.sum(
        dim=-1, keepdim=True
    )
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    # Broadcast scalar SwiGLU parameters to per-expert tensors, as the
    # kernel expects one value per expert.
    if alpha is not None:
        alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts,), beta, device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts,), limit, device=hidden_states.device)

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped,
        fc2_expert_weights=w2_q,
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)],
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=alpha,
        swiglu_beta=beta,
        swiglu_limit=limit,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_w4_group_scaling=True,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not (
        current_platform.is_cuda()
        and current_platform.is_device_capability_family(100)
        and has_flashinfer()
    ),
    reason="NVIDIA GPU sm100 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float | None,
    beta: float | None,
    limit: float | None,
):
    """Compare the FlashInfer CUTLASS fused MoE (MXFP8 activations, MXFP4
    weights, SM100 path) against a float32 ``reference_moe`` run on the
    dequantized weights with an "mxfp8" activation mode."""
    torch.manual_seed(42)
    device = "cuda:0"

    # Inputs
    hidden_states = torch.randn(
        num_tokens, hidden_size, device=device, dtype=torch.bfloat16
    )
    # Float weights in w13 format [w1; w3]
    w13 = (
        torch.randn(
            num_experts,
            2 * intermediate_size,
            hidden_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 10
    )
    w2 = (
        torch.randn(
            num_experts,
            hidden_size,
            intermediate_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 10
    )
    # Bias contiguous [b1; b3]
    bias13 = (
        torch.randn(
            num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
        )
        * 10
    )
    bias2 = (
        torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
    )
    router_logits = torch.rand(
        num_tokens, num_experts, dtype=torch.float32, device=device
    )

    # Quantize weights to MXFP4 per expert (SM100 path)
    from flashinfer import mxfp4_quantize

    def quant_mxfp4_batches(a: torch.Tensor, e: int):
        # Quantize each expert slice independently; stack back into a batch.
        qs, sfs = [], []
        for i in range(e):
            q, sf = mxfp4_quantize(a[i].cuda())
            qs.append(q)
            sfs.append(sf)
        return torch.stack(qs), torch.stack(sfs)

    def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
        # Per-expert dequantization; the scale tensor is flattened per batch.
        num_batches = mat_fp4.size(0)
        scale_tensor = scale_tensor.view(num_batches, -1)
        from flashinfer import mxfp4_dequantize

        return torch.stack(
            [
                mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
                for b in range(num_batches)
            ]
        )

    w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
    w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)

    # Reference result using dequantized tensors and reference_moe
    w13_ref = (
        dequant_mxfp4_batches(
            w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1)
        )
        .to(torch.float32)
        .reshape(num_experts, 2 * intermediate_size, hidden_size)
        .to(device)
    )
    w2_ref = (
        dequant_mxfp4_batches(
            w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1)
        )
        .to(torch.float32)
        .reshape(num_experts, hidden_size, intermediate_size)
        .to(device)
    )

    # Quantize activations for SM100 path and dequantize for reference
    hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
    # Reference uses BF16 input but quantizes intermediate activation to MXFP8
    ref = reference_moe(
        router_logits.to(torch.float32),
        topk,
        num_experts,
        hidden_states.to(torch.float32),
        w13_ref,
        bias13.to(torch.float32),
        w2_ref,
        bias2.to(torch.float32),
        alpha,
        beta,
        limit,
        "mxfp8",
    )

    # Prepare inputs for FlashInfer CUTLASS fused MoE
    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    # Swap halves to arrange as [w3; w1] (kernel expectation)
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    # Swap scales halves to match swapped weights
    s1, s3 = torch.chunk(w13_scale, 2, dim=1)
    w13_scale_swapped = torch.cat([s3, s1], dim=1)

    # Biases must be swapped the same way as the weights.
    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

    # Build routing for kernel
    routing_weights = torch.nn.functional.softmax(
        router_logits, dim=1, dtype=torch.float32
    )
    token_final_scales, token_selected_experts = torch.topk(
        routing_weights, topk, dim=-1
    )
    token_final_scales = token_final_scales / token_final_scales.sum(
        dim=-1, keepdim=True
    )
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    # Broadcast scalar SwiGLU parameters to per-expert tensors (or None).
    if alpha is not None:
        alpha_t = torch.full((num_experts,), alpha, device=hidden_states.device)
    else:
        alpha_t = None
    if beta is not None:
        beta_t = torch.full((num_experts,), beta, device=hidden_states.device)
    else:
        beta_t = None
    if limit is not None:
        limit_t = torch.full((num_experts,), limit, device=hidden_states.device)
    else:
        limit_t = None

    # Quant scales for SM100 MXFP8+MXFP4 path
    fake_input_scale = torch.ones(num_experts, device=device)
    quant_scales = [
        w13_scale_swapped.view(torch.int32),
        fake_input_scale,
        w2_scale.view(torch.int32),
        fake_input_scale,
    ]

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states_q,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        # Weights are reinterpreted as int64 words for this kernel entry.
        fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
        fc2_expert_weights=w2_q.contiguous().view(torch.long),
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=quant_scales,
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=alpha_t,
        swiglu_beta=beta_t,
        swiglu_limit=limit_t,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_mxfp8_act_scaling=True,
        input_sf=hidden_states_sf,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
|
||||
356
tests/kernels/moe/test_pplx_cutlass_moe.py
Normal file
356
tests/kernels/moe/test_pplx_cutlass_moe.py
Normal file
@@ -0,0 +1,356 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
# pplx_kernels is an optional dependency; record availability so tests can
# be skipped gracefully when it is absent.
try:
    from pplx_kernels import AllToAll
    from pplx_kernels.nvshmem import (
        nvshmem_alloc_empty_unique_id,
        nvshmem_finalize,
        nvshmem_get_unique_id,
        nvshmem_init,
    )

    has_pplx = True
except ImportError:
    has_pplx = False

# Skip marker for tests that require the PPLX all-to-all kernels.
requires_pplx = pytest.mark.skipif(
    not has_pplx,
    reason="Requires PPLX kernels",
)

# Expert-count and top-k sweeps shared by the parametrized tests below.
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]
|
||||
|
||||
|
||||
def rank_chunk(num, r, w):
    """Return how many of *num* items rank *r* of *w* ranks receives when
    the items are split as evenly as possible (earlier ranks absorb the
    remainder)."""
    base, leftover = divmod(num, w)
    if r < leftover:
        return base + 1
    return base
|
||||
|
||||
|
||||
def chunk_by_rank(t, r, w):
    """Return the contiguous slice of *t* (along dim 0) owned by rank *r*
    out of *w* ranks, matching the uneven split produced by rank_chunk."""
    total = t.shape[0]
    size = rank_chunk(total, r, w)
    extra = total % w
    if extra and r >= extra:
        # This rank sits after the "long" chunks: skip all (base+1)-sized
        # chunks, then the preceding base-sized chunks.
        start = (total // w + 1) * extra + (r - extra) * size
    else:
        # Even split, or an early rank that owns one of the long chunks.
        start = r * size
    return t[start : start + size].contiguous()
|
||||
|
||||
|
||||
def pplx_cutlass_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scale: torch.Tensor,
    out_dtype,
    per_act_token: bool,
    per_out_ch: bool,
    group_name: str | None,
):
    """Run the CUTLASS batched FP8 experts through a PPLX all-to-all
    prepare/finalize on this rank and return this rank's output slice.

    ``group_name is None`` selects the internode (NVSHMEM) all-to-all;
    otherwise the intranode all-to-all over the named process group.
    """
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize,
    )

    assert torch.cuda.current_device() == pgi.local_rank

    num_tokens, hidden_dim = a.shape
    intermediate_dim = w2.shape[2]
    num_experts = w1.shape[0]
    block_size = hidden_dim  # TODO support more cases
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
    max_num_tokens = rank_chunk(num_tokens, 0, world_size)
    topk = topk_ids.shape[1]

    if block_size == hidden_dim:
        scale_elems = 4  # hack to circumvent pplx data format requirements
    else:
        scale_elems = (hidden_dim + block_size - 1) // block_size

    args = dict(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim,  # because a.dtype.itemsize == 1
        hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
    )

    if group_name is None:
        ata = AllToAll.internode(**args)
    else:
        args["group_name"] = group_name
        ata = AllToAll.intranode(**args)

    # Move all weights/scales to this rank's device before chunking.
    w1 = w1.to(device)
    w2 = w2.to(device)
    w1_scale = w1_scale.to(device)
    w2_scale = w2_scale.to(device)
    a1_scale = a1_scale.to(device)

    assert num_experts % world_size == 0
    num_local_experts = cdiv(num_experts, world_size)
    num_dispatchers = pgi.world_size // dp_size

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens=max_num_tokens,
        num_local_experts=num_local_experts,
        num_dispatchers=num_dispatchers,
    )

    # Per-expert leading strides for the grouped GEMMs (row-major layouts).
    ab_strides1 = torch.full(
        (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
    )
    ab_strides2 = torch.full(
        (num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64
    )
    c_strides1 = torch.full(
        (num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64
    )
    c_strides2 = torch.full(
        (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
    )

    experts = CutlassBatchedExpertsFp8(
        num_local_experts,
        num_dispatchers,
        out_dtype,
        ab_strides1,
        ab_strides2,
        c_strides1,
        c_strides2,
        fp8_w8a8_moe_quant_config(
            per_act_token_quant=per_act_token,
            per_out_ch_quant=per_out_ch,
            # Each rank only holds scales for its local experts.
            w1_scale=chunk_by_rank(w1_scale, rank, world_size),
            w2_scale=chunk_by_rank(w2_scale, rank, world_size),
            a1_scale=chunk_by_rank(a1_scale, rank, world_size)
            if per_act_token
            else a1_scale[rank],
        ),
    )

    fused_cutlass_experts = FusedMoEModularKernel(
        prepare_finalize,
        experts,
    )

    # Each rank processes its own token chunk with its expert shard.
    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device)
    chunk_topk_ids = (
        chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device)
    )

    out = fused_cutlass_experts(
        a_chunk,
        chunk_by_rank(w1, rank, world_size),
        chunk_by_rank(w2, rank, world_size),
        chunk_topk_weight,
        chunk_topk_ids,
        global_num_experts=num_experts,
        expert_map=None,  # TODO
    )

    torch.cuda.synchronize()

    ata.destroy()

    # Trim padding: only the first rank_num_tokens rows belong to this rank.
    return out[:rank_num_tokens]
|
||||
|
||||
|
||||
# Default engine config, used as the active config context for the MoE layers.
vllm_config = VllmConfig()
|
||||
|
||||
|
||||
def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a1_scale: torch.Tensor,
    out_dtype,
    a_full: torch.Tensor,
    w1_full: torch.Tensor,
    w2_full: torch.Tensor,
    per_act_token: bool,
    per_out_ch: bool,
    use_internode: bool,
):
    """Per-rank worker: set up the communication backend, run both the torch
    reference experts and the PPLX CUTLASS MoE, and compare this rank's
    output slices."""
    try:
        if use_internode:
            # Rank 0 creates the NVSHMEM unique id; others receive it.
            uid = (
                nvshmem_get_unique_id()
                if pgi.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            # NOTE(review): group_name is not bound on this branch, so
            # pplx_cutlass_moe below would raise NameError. Only
            # use_internode=False is exercised by the test — confirm intent.
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
            group_name = cpu_group.group_name

        with set_current_vllm_config(vllm_config):
            # Unquantized reference over the full (unsharded) tensors.
            torch_output = torch_experts(
                a_full, w1_full, w2_full, topk_weights, topk_ids
            )
            pplx_output = pplx_cutlass_moe(
                pgi,
                dp_size,
                a,
                w1,
                w2,
                w1_scale,
                w2_scale,
                topk_weights,
                topk_ids,
                a1_scale,
                out_dtype,
                per_act_token,
                per_out_ch,
                group_name,
            )

        # Compare only this rank's slice of the reference output.
        torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
            pplx_output.device
        )

        # Uncomment if more debugging is needed
        # print("PPLX OUT:", pplx_output)
        # print("TORCH OUT:", torch_output)

        torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
    finally:
        if use_internode:
            nvshmem_finalize()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [2, 224])
@pytest.mark.parametrize("n", [3072])
@pytest.mark.parametrize("k", [1536])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])  # , [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@multi_gpu_test(num_gpus=2)
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
@requires_pplx
def test_cutlass_moe_pplx(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Quantize random weights to FP8, then launch _pplx_moe across ranks to
    compare the PPLX CUTLASS MoE against the torch reference."""
    current_platform.seed_everything(7)

    with set_current_vllm_config(vllm_config):
        dtype = torch.half

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0

        # One scale per output channel when per_out_ch, else per tensor.
        n_b_scales = 2 * n if per_out_ch else 1
        k_b_scales = k if per_out_ch else 1

        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
        w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
        w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)

        # Quantize each expert's weights to FP8 with dynamic scales.
        for expert in range(e):
            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
                w1[expert], use_per_token_if_dynamic=per_out_ch
            )
            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
                w2[expert], use_per_token_if_dynamic=per_out_ch
            )

        # Dequantized copies so the torch reference sees the same values
        # as the FP8 kernel.
        w1_d = torch.empty_like(w1)
        w2_d = torch.empty_like(w2)
        for expert in range(e):
            w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
            w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        world_size, dp_size = world_dp_size
        a_scale1 = (
            torch.randn(
                (m if per_act_token else 1, 1), device="cuda", dtype=torch.float32
            )
            / 10.0
        )
        if not per_act_token:
            # Replicate the per-tensor scale so each rank can index its own row.
            a_scale1 = a_scale1.repeat(world_size, 1)

        parallel_launch(
            world_size,
            _pplx_moe,
            dp_size,
            a,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            topk_weights,
            topk_ids,
            a_scale1,
            dtype,
            a,
            w1_d,
            w2_d,
            per_act_token,
            per_out_ch,
            use_internode,
        )
|
||||
1018
tests/kernels/moe/test_pplx_moe.py
Normal file
1018
tests/kernels/moe/test_pplx_moe.py
Normal file
File diff suppressed because it is too large
Load Diff
219
tests/kernels/moe/test_rocm_aiter_topk.py
Normal file
219
tests/kernels/moe/test_rocm_aiter_topk.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# This is a test for the AITER ops.
|
||||
# It tests if the AITER ops are
|
||||
# 1. correctly registered as custom ops
|
||||
# 2. correctly defined the relationship between
|
||||
# implementation and fake function
|
||||
# 3. can be used with torch.compile
|
||||
# This file will be skipped if AITER is not installed
|
||||
# and the platform is not ROCm.
|
||||
|
||||
import importlib.util
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# this import statement is needed to ensure the ops are registered
|
||||
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe # noqa: F401
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# need to import once to ensure the ops are registered
|
||||
# Check if aiter package is installed
|
||||
# Detect the aiter package without importing it.
aiter_available = importlib.util.find_spec("aiter") is not None

# Skip the entire module unless running on ROCm with aiter installed.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_rocm() and aiter_available),
    reason="AITER ops are only available on ROCm with aiter package installed",
)
|
||||
|
||||
|
||||
def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
    """Verify the biased grouped top-k custom op is registered and callable."""
    op_namespace = torch.ops.vllm
    # Registration must expose the op under torch.ops.vllm ...
    assert hasattr(op_namespace, "rocm_aiter_biased_grouped_topk")
    # ... and the exposed attribute must be invocable.
    assert callable(op_namespace.rocm_aiter_biased_grouped_topk)
|
||||
|
||||
|
||||
def test_rocm_aiter_grouped_topk_custom_op_registration():
    """Verify the grouped top-k custom op is registered and callable."""
    op_namespace = torch.ops.vllm
    # Registration must expose the op under torch.ops.vllm ...
    assert hasattr(op_namespace, "rocm_aiter_grouped_topk")
    # ... and the exposed attribute must be invocable.
    assert callable(op_namespace.rocm_aiter_grouped_topk)
|
||||
|
||||
|
||||
def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
    """Test that the op can be used with torch.compile: the compiled and
    eager invocations must produce matching top-k weights and ids."""
    # Create test tensors
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scale_factor = 1.0

    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
    e_score_correction_bias = torch.randn(
        (expert,), dtype=torch.bfloat16, device="cuda"
    )

    device = gating_output.device
    # Output buffers: the op writes its results into these in place.
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

    # Define a function that uses the op
    def biased_grouped_topk_fn(
        gating_output, e_score_correction_bias, topk_weights, topk_ids
    ):
        return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output,
            e_score_correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            scale_factor,
        )

    # Verify the op's fake implementation
    # NOTE(review): test_utils is a plain string (no trailing comma), which
    # opcheck accepts as a single test name.
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_biased_grouped_topk,
        (gating_output, e_score_correction_bias, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "routed_scaling_factor": scale_factor,
        },
        test_utils=("test_faketensor"),
    )

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(
        biased_grouped_topk_fn,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=False,
    )

    # Separate output buffers for the eager and compiled runs.
    topk_weights_original = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)

    topk_weights_compiled = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    biased_grouped_topk_fn(
        gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original
    )
    compiled_fn(
        gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled
    )

    # Sort the results for comparison since the order might not be deterministic
    topk_ids_original, indices_original = torch.sort(topk_ids_original)
    topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)

    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
    topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)

    # Verify results match
    assert torch.allclose(
        topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
    )
    assert torch.allclose(topk_ids_original, topk_ids_compiled)
|
||||
|
||||
|
||||
def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
    """Test that the op can be used with torch.compile: the compiled and
    eager invocations must produce matching top-k weights and ids."""
    # Create test tensors
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scoring_func = "softmax"
    scale_factor = 1.0

    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")

    device = gating_output.device
    # Output buffers: the op writes its results into these in place.
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

    # Define a function that uses the op
    def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
        return torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            scoring_func,
            scale_factor,
        )

    # Verify the op's fake implementation
    # NOTE(review): test_utils is a plain string (no trailing comma), which
    # opcheck accepts as a single test name.
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_grouped_topk,
        (gating_output, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "scoring_func": scoring_func,
            "routed_scaling_factor": scale_factor,
        },
        test_utils=("test_faketensor"),
    )

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(
        grouped_topk_fn,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=False,
    )

    # Separate output buffers for the eager and compiled runs.
    topk_weights_original = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)

    topk_weights_compiled = torch.empty(
        (token, topk), dtype=torch.float32, device=device
    )
    topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    grouped_topk_fn(
        gating_output, topk_weights_original, topk_ids_original, scoring_func
    )
    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func)

    # Sort the results for comparison since the order might not be deterministic
    topk_ids_original, indices_original = torch.sort(topk_ids_original)
    topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)

    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
    topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)

    # Verify results match
    assert torch.allclose(
        topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
    )
    assert torch.allclose(topk_ids_original, topk_ids_compiled)
|
||||
293
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
Normal file
293
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
persistent_masked_m_silu_mul_quant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
|
||||
# Skip at module scope on platforms that only support the fnuz fp8 variant,
# since these tests require float8_e4m3fn.
if current_platform.is_fp8_fnuz():
    pytest.skip(
        "Tests in this file require float8_e4m3fn and platform does not support",
        allow_module_level=True,
    )

fp8_dtype = torch.float8_e4m3fn

# (num_experts, num_tokens, hidden_dim, quant_dtype) combinations — TODO
# confirm field order against the test that consumes CASES.
CASES = [
    (1, 1, 128, fp8_dtype),
    (1, 4, 128 * 1, fp8_dtype),
    (2, 4, 128 * 2, fp8_dtype),
    (1, 4, 128 * 3, fp8_dtype),
    (8, 16, 128 * 4, fp8_dtype),
    (8, 16, 128 * 5, fp8_dtype),
    (8, 16, 128 * 6, fp8_dtype),
    (8, 16, 128 * 7, fp8_dtype),
    (8, 16, 128 * 8, fp8_dtype),
    (8, 16, 128 * 9, fp8_dtype),
    (8, 64, 7168, fp8_dtype),
    (8, 128, 128 * 33, fp8_dtype),
    (1, 4, 128 * 10, fp8_dtype),
    (8, 128, 7168, fp8_dtype),
    (8, 512, 7168, fp8_dtype),
    (8, 1024, 7168, fp8_dtype),
    (17, 31, 768, fp8_dtype),
    (32, 64, 256, fp8_dtype),
    (256, 8, 7168, fp8_dtype),
    (256, 32, 7168, fp8_dtype),
    (256, 64, 7168, fp8_dtype),
    # Only add a few fnuz tests to help with long CI times.
    (8, 512, 7168, torch.float8_e4m3fnuz),
    (8, 1024, 7168, torch.float8_e4m3fnuz),
]
|
||||
|
||||
|
||||
def as_uint8(x) -> torch.Tensor:
    """Return a contiguous copy of *x* reinterpreted as raw uint8 bytes.

    Copying into a freshly allocated tensor guarantees contiguity, which
    ``Tensor.view(torch.uint8)`` requires for the dtype reinterpretation.
    """
    contiguous_copy = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    contiguous_copy.copy_(x)
    return contiguous_copy.view(torch.uint8)
|
||||
|
||||
|
||||
def silu(x: torch.Tensor) -> torch.Tensor:
    """Reference SiLU: x * sigmoid(x), computed in float32 and returned
    as bfloat16 to mirror the kernel's internal precision."""
    in_f32 = x.to(torch.float32)
    # Same float32 arithmetic and operation order as the kernel reference:
    # x / (1 + exp(-x)).
    denom = torch.tensor([1.0], device=x.device, dtype=torch.float32) + torch.exp(
        -in_f32
    )
    out_f32 = in_f32 / denom
    assert out_f32.dtype == torch.float32
    return out_f32.to(torch.bfloat16)
|
||||
|
||||
|
||||
def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool):
    """Group-wise fp8 quantization reference, done in bfloat16.

    Splits ``x`` into groups of ``group_size`` along the last dim, computes a
    per-group scale ``amax / fp8_max`` (optionally rounded up to a power of
    two when ``ceil_ue8m0``), and quantizes to ``fp8_dtype``.  All arithmetic
    is deliberately done in bfloat16 so rounding matches the kernel under test.

    Returns ``(xq, xs)``: fp8 data with ``x``'s shape and bfloat16 scales of
    shape ``(-1, last_dim // group_size)``.
    """
    # Constants kept as bf16 tensors so every intermediate stays in bf16.
    eps_bf16 = torch.tensor([1e-10], device=x.device, dtype=torch.bfloat16)
    one_bf16 = torch.tensor([1.0], device=x.device, dtype=torch.bfloat16)
    fp8_max_bf16 = torch.tensor(
        [torch.finfo(fp8_dtype).max], device=x.device, dtype=torch.bfloat16
    )
    fp8_min_bf16 = torch.tensor(
        [torch.finfo(fp8_dtype).min], device=x.device, dtype=torch.bfloat16
    )
    fp8_max_inv = one_bf16 / fp8_max_bf16
    assert fp8_max_inv.dtype == torch.bfloat16

    assert x.size(-1) % group_size == 0
    num_groups = x.numel() // group_size
    x_og_shape = x.shape

    # Flatten into (num_groups, group_size) rows for per-group reduction.
    x = x.to(torch.bfloat16)
    x = x.view((-1, group_size))
    # eps clamp avoids a zero scale (and the div-by-zero below) for all-zero groups.
    amax = x.abs().amax(dim=1).clamp(min=eps_bf16)
    assert amax.dtype == torch.bfloat16
    s = amax * fp8_max_inv

    if ceil_ue8m0:
        # Round the scale UP to the nearest power of two (UE8M0 semantics).
        s = torch.exp2(
            torch.ceil(torch.log2(s).to(torch.bfloat16)).to(torch.bfloat16)
        ).to(torch.bfloat16)

    inv_s = one_bf16 / s
    inv_s = inv_s.view((num_groups, 1))
    # Clamp to the fp8 representable range before the narrowing cast.
    xq = torch.clamp(x * inv_s, min=fp8_min_bf16.item(), max=fp8_max_bf16.item()).to(
        fp8_dtype
    )

    xq = xq.view(x_og_shape)
    xs = s.view((-1, xq.size(-1) // group_size))
    return xq, xs
|
||||
|
||||
|
||||
def silu_mul_quant(
    gate: torch.Tensor, up: torch.Tensor, group_size: int, ceil_ue8m0: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference silu(gate) * up followed by group-wise fp8 quantization.

    Both inputs must be bfloat16 with a last dim divisible by *group_size*.
    Returns the (fp8 data, bf16 scales) pair produced by ``do_quant``.
    """
    # Validate shapes and dtypes up front.
    assert gate.size(-1) % group_size == 0
    assert up.size(-1) % group_size == 0
    assert gate.dtype == torch.bfloat16
    assert up.dtype == torch.bfloat16

    activated = silu(gate)
    assert activated.dtype == torch.bfloat16

    # Gated activation: silu(gate) * up, still in bf16.
    gated = activated * up
    assert gated.dtype == torch.bfloat16

    return do_quant(gated, group_size, ceil_ue8m0)
|
||||
|
||||
|
||||
def pack_scales(x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor:
    """
    pack float32 scales into a int32 tensor

    For each valid token row, each float32 scale is reduced to its 8-bit
    biased exponent (bits 23..30 of the IEEE-754 word, extracted with
    ``>> 23`` below) and stored as one uint8; the uint8 buffer is then
    reinterpreted as int32, packing 4 exponents per int32 (UE8M0 layout).
    Rows beyond ``tokens_per_expert[e]`` are left uninitialized.
    """
    assert x.dtype == torch.float32
    E, T, G = x.size()

    # Add i32_padding here so we can view it as a i32 tensor later on.
    i32_padding = round_up(G, 4) - G
    ref_s_i8 = torch.empty((E, T, G + i32_padding), dtype=torch.uint8, device="cuda")
    for e in range(E):
        nt = tokens_per_expert[e].item()
        # NOTE: assumes scales are non-negative so the sign bit is 0 and
        # the shift yields exactly the exponent byte.
        ref_s_i8[e, :nt, :G] = x[e, :nt].view(torch.int32) >> 23

    ref_s_i32 = ref_s_i8.view(torch.int32)

    return ref_s_i32
|
||||
|
||||
|
||||
def ref_with_scale_fmt(
    E: int,
    T: int,
    H: int,
    group_size: int,
    tokens_per_expert: torch.Tensor,
    gate: torch.Tensor,
    up: torch.Tensor,
    scale_fmt: DeepGemmQuantScaleFMT,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    The precision types of the operations triggered by this function
    match closely with the kernel implementation so we compare more
    accurately.

    Returns (fp8 activations of shape (E, T, H), scales).  Scales are
    float32 of shape (E, T, cdiv(H, group_size)), or int32-packed UE8M0
    exponents when ``scale_fmt == UE8M0``.  Rows past each expert's
    token count are left uninitialized and must not be compared.
    """
    scale_dtype = (
        torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32
    )
    # Both UE8M0 variants round scales up to a power of two.
    ceil_ue8m0 = scale_fmt in [
        DeepGemmQuantScaleFMT.UE8M0,
        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
    ]

    ref_q = torch.empty((E, T, H), dtype=fp8_dtype, device="cuda")
    ref_s_f32 = torch.empty(
        (E, T, cdiv(H, group_size)), dtype=torch.float32, device="cuda"
    )

    # Quantize only the valid (first nt) tokens of each expert.
    for e in range(E):
        nt = tokens_per_expert[e].item()
        if nt == 0:
            continue
        ref_q[e, :nt], ref_s_f32[e, :nt] = silu_mul_quant(
            gate[e, :nt], up[e, :nt], group_size, ceil_ue8m0=ceil_ue8m0
        )

    if scale_dtype == torch.float32:
        return ref_q, ref_s_f32

    # UE8M0: pack the float32 scales' exponents into int32 words.
    assert scale_dtype == torch.int32
    return ref_q, pack_scales(ref_s_f32, tokens_per_expert)
|
||||
|
||||
|
||||
def token_random(E: int, T: int, H2: int, tokens_per_expert: torch.Tensor) -> torch.Tensor:
    """
    Initialize each token in a random range so we test a range of
    scale values.

    Each valid token row is filled uniformly from (-2**exp, 2**exp) with a
    random exp in [1, 19], so different rows exercise widely different
    quantization scales.  Rows beyond ``tokens_per_expert[e]`` stay
    uninitialized (torch.empty) — callers must only read the valid rows.
    """
    y = torch.empty((E, T, H2), dtype=torch.bfloat16, device="cuda")
    for e in range(E):
        for t in range(tokens_per_expert[e].item()):
            exp = random.choice(range(1, 20))
            y[e, t].uniform_(-(2**exp), 2**exp)
    return y
|
||||
|
||||
|
||||
@pytest.mark.parametrize("E,T,H,fp8_type", CASES)
@torch.inference_mode()
def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dtype):
    """Compare persistent_masked_m_silu_mul_quant against the bf16 torch
    reference for every supported scale format, per expert and only over
    each expert's valid token rows.

    E = experts, T = max tokens per expert, H = hidden size (input is 2*H:
    gate and up halves concatenated on the last dim).
    """
    group_size = 128
    current_platform.seed_everything(42)

    # Random valid-token count in [0, T) per expert (masked-M layout).
    tokens_per_expert = torch.randint(
        low=0,
        high=T,
        size=(E,),
        dtype=torch.int32,
        device="cuda",
    )

    # Input tensor of shape (E, T, 2*H)
    y = token_random(E, T, 2 * H, tokens_per_expert)

    gate = y[..., :H].to(torch.bfloat16)
    up = y[..., H:].to(torch.bfloat16)

    scale_fmts = [
        DeepGemmQuantScaleFMT.FLOAT32,
        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
        DeepGemmQuantScaleFMT.UE8M0,
    ]

    # Run the SiLU V2 kernel
    for scale_fmt in scale_fmts:
        y_q, y_s = persistent_masked_m_silu_mul_quant(
            y,
            tokens_per_expert,
            group_size=group_size,
            quant_scale_fmt=scale_fmt,
        )

        ref_y_q, ref_y_s = ref_with_scale_fmt(
            E, T, H, group_size, tokens_per_expert, gate, up, scale_fmt=scale_fmt
        )

        # deepgemm scales transform
        # Cross-check the reference UE8M0 packing against DeepGEMM's own
        # layout transform when DeepGEMM and SM100 are available.
        dg_scales = None
        if (
            has_deep_gemm()
            and current_platform.has_device_capability(100)
            and scale_fmt == DeepGemmQuantScaleFMT.UE8M0
        ):
            from deep_gemm import transform_sf_into_required_layout

            _q, _s = ref_with_scale_fmt(
                E,
                T,
                H,
                group_size,
                tokens_per_expert,
                gate,
                up,
                scale_fmt=DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
            )
            dg_scales = transform_sf_into_required_layout(
                sf=_s,
                mn=_q.size(1),
                k=_q.size(2),
                recipe=(1, 128, 128),
                num_groups=_q.size(0),
                is_sfa=True,
            )

        expected_scale_dtype = (
            torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32
        )
        assert y_s.dtype == expected_scale_dtype
        assert ref_y_s.dtype == expected_scale_dtype

        # Compare per expert, restricted to the valid token rows.
        for e in range(E):
            nt = tokens_per_expert[e].item()

            torch.testing.assert_close(
                y_q[e, :nt].to(torch.float32),
                ref_y_q[e, :nt].to(torch.float32),
            )

            if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
                # Compare packed exponents byte-wise; only the first G
                # bytes per row are meaningful (the rest is int32 padding).
                G = H // group_size
                y_s_sliced = as_uint8(y_s[e])
                ref_s_sliced = as_uint8(ref_y_s[e])
                torch.testing.assert_close(y_s_sliced[:nt, :G], ref_s_sliced[:nt, :G])
                if dg_scales is not None:
                    dg_sliced = as_uint8(dg_scales[e])
                    torch.testing.assert_close(y_s_sliced[:nt, :G], dg_sliced[:nt, :G])
            else:
                torch.testing.assert_close(
                    y_s[e, :nt],
                    ref_y_s[e, :nt],
                )
|
||||
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_per_token_group_quant_fp8_colmajor,
|
||||
silu_mul_per_token_group_quant_fp8_colmajor,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
|
||||
FLOAT8_DTYPE = torch.float8_e4m3fn
|
||||
GROUP_SIZE = 128
|
||||
|
||||
|
||||
def reference_quant(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reference triton quant kernel from,
    vllm.model_executor.layers.quantization.utils.fp8_utils

    Launches the per-token-group fp8 quant kernel directly (one program per
    GROUP_SIZE-sized group) and returns (fp8 data shaped like ``x``,
    column-major float32 scales of shape ``x.shape[:-1] + (n_groups,)``).
    """

    x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE)

    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    # One triton program per group of GROUP_SIZE elements.
    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_DTYPE)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s
|
||||
|
||||
|
||||
def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference pipeline: fused silu_and_mul (custom C op, halves the last
    dim) followed by the reference per-token-group fp8 quantization."""
    T, N = x.size()
    # silu_and_mul writes silu(x[:, :N/2]) * x[:, N/2:] into ref_act_out.
    ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
    torch.ops._C.silu_and_mul(ref_act_out, x)
    return reference_quant(ref_act_out, use_ue8m0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("T", [128, 256, 512])
@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2])
def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int):
    """Fused silu_mul + per-token-group fp8 (column-major scales) kernel
    vs. the unfused reference (custom-op act + reference triton quant)."""
    current_platform.seed_everything(42)

    # NOTE(review): `input` shadows the builtin; harmless in a test body.
    input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

    # Match the UE8M0 scale mode the production path would pick here.
    use_ue8m0 = is_deep_gemm_e8m0_used()

    # Test
    output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor(
        input, use_ue8m0=use_ue8m0
    )

    # Reference
    ref_output, ref_output_scales = reference(input, use_ue8m0)

    torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32))
    torch.testing.assert_close(output_scales, ref_output_scales)
|
||||
170
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
Normal file
170
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_triton_moe_channel_fp8_kernel.py
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import fused_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Matrix multiplication function that supports per-token input
    quantization and per-column weight quantization.

    A: (..., in_features) activations; As: per-token scales [M, 1].
    B: (out_features, in_features) 2D contiguous weight; Bs: per-column
    scales broadcast over the output dim.  Math is done in float32 and
    cast to *output_dtype* at the end.
    """
    a_f32 = A.to(torch.float32)
    b_f32 = B.to(torch.float32)

    assert a_f32.shape[-1] == b_f32.shape[-1], "Dimension mismatch"
    assert b_f32.ndim == 2 and b_f32.is_contiguous(), "B must be a 2D contiguous tensor"

    # Flatten leading dims of A into a token axis.
    in_features = a_f32.shape[-1]
    num_tokens = a_f32.numel() // in_features
    b_t = b_f32.t()  # (in_features, out_features)
    out_features = b_t.shape[1]
    out_shape = a_f32.shape[:-1] + (out_features,)

    product = torch.matmul(a_f32.reshape(num_tokens, in_features), b_t)
    # Apply per-token (rows) then per-column (cols) dequant scales.
    scaled = As * product * Bs.view(1, -1)

    return scaled.reshape(out_shape).to(output_dtype)
|
||||
|
||||
|
||||
def fp8_mask(a, mask):
    """Boolean-mask the first dim of a tensor whose dtype lacks native
    advanced-indexing support (e.g. fp8) by round-tripping through an
    int8 view of the same bytes."""
    original_dtype = a.dtype
    as_int8 = a.view(torch.int8)
    return as_int8[mask].view(original_dtype)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
    """This function performs fused moe with per-column fp8
    quantization using native torch.

    (Despite the original "int8" wording, all quantization below uses
    ``ops.scaled_fp8_quant``.)  Routing is softmax + topk without
    renormalization; each expert processes only its assigned token copies
    and the per-expert outputs are combined with the routing weights.
    """

    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    # Calculate routing
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            # (fp8_mask indexes fp8/f32 tensors via an int8 byte view.)
            inter_out = native_w8a8_per_token_matmul(
                fp8_mask(a_q, mask),
                w1[i],
                fp8_mask(a_s, mask),
                w1_s[i],
                output_dtype=a.dtype,
            )
            # Activation function
            act_out = SiluAndMul().forward_native(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = ops.scaled_fp8_quant(
                act_out, use_per_token_if_dynamic=True
            )

            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
            )
    # Apply routing weights and sum
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Sets the default CUDA device for all tests in this module."""
    # autouse + module scope: runs once before any test here, so tensors
    # created without an explicit device land on CUDA.
    torch.set_default_device("cuda")
|
||||
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33]
|
||||
N = [128, 1024]
|
||||
K = [256, 4096]
|
||||
E = [8]
|
||||
TOP_KS = [2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "M, N, K, E, topk, dtype, seed",
    itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
    """Fused fp8 w8a8 MoE (per-act-token / per-column weight quant) vs. the
    native torch reference; accepts up to 5% mean relative error."""
    torch.manual_seed(seed)
    # Initialize fp8 quantization parameters
    factor_for_scale = 1e-2
    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_max = finfo.max
    fp8_min = finfo.min

    # Input tensor
    # M * K
    a = torch.randn((M, K), dtype=dtype) / 10

    # Generate fp8 weights (random values spanning the full e4m3 range)
    w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
    w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
    w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    # Generate scale for each column (per-column quantization)
    w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
    w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
    score = torch.randn((M, E), dtype=dtype)

    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
        out = fused_moe(
            a,
            w1,
            w2,
            score,
            topk,
            renormalize=False,
            quant_config=fp8_w8a8_moe_quant_config(
                per_act_token_quant=True,
                w1_scale=w1_s,
                w2_scale=w2_s,
                block_shape=None,  # Not using block quantization
            ),
        )

    # Check results
    # Mean relative error over all elements — tolerant of fp8 rounding noise.
    rel_diff = torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
    ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
    assert rel_diff < 0.05
|
||||
521
tests/kernels/moe/utils.py
Normal file
521
tests/kernels/moe/utils.py
Normal file
@@ -0,0 +1,521 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts,
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
|
||||
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run the default triton ``fused_experts`` path with a quant config
    assembled from the given scales; quant_dtype=None means unquantized."""
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
|
||||
|
||||
|
||||
def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Run MoE through a single-rank modular kernel built from
    BatchedPrepareAndFinalize + BatchedTritonExperts (triton batched path)."""
    # Round capacity up to a multiple of 64 for the batched kernel layout.
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # Single rank / single dispatcher: all experts are local.
    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
        ),
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
        ),
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> torch.Tensor:
    """Same as ``batched_moe`` but with the reference NaiveBatchedExperts
    implementation instead of the triton batched experts."""
    # Round capacity up to a multiple of 64 for the batched kernel layout.
    max_num_tokens = round_up(a.shape[0], 64)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )

    # Single rank / single dispatcher: all experts are local.
    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(
            max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
        ),
        NaiveBatchedExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            quant_config=quant_config,
        ),
    )

    return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def chunk_scales(
|
||||
scales: torch.Tensor | None, start: int, end: int
|
||||
) -> torch.Tensor | None:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
return scales
|
||||
else:
|
||||
return scales[start:end]
|
||||
return None
|
||||
|
||||
|
||||
def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None = None,
    block_shape: list[int] | None = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Build random (E, m, k) activations plus a quantized copy.

    Returns ``(a, a_q, a_scale)``.  With quant_dtype=None, a_q aliases a
    and a_scale is None; otherwise each expert slice is quantized with
    ``moe_kernel_quantize_input`` and the per-expert scales are stacked.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
    a_q = a
    a_scale = None

    if quant_dtype is not None:
        assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
            "only fp8/int8 supported"
        )
        a_q = torch.zeros_like(a, dtype=quant_dtype)
        a_scale_l = [None] * E
        for e in range(E):
            a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
                a[e], None, quant_dtype, per_act_token_quant, block_shape
            )
        a_scale = torch.stack(a_scale_l)

        # Per-tensor quant: reshape to (E, 1, 1) so it broadcasts per expert.
        if not per_act_token_quant and block_shape is None:
            a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale
|
||||
|
||||
|
||||
def moe_quantize_weights(
    w: torch.Tensor,
    w_s: torch.Tensor | None,
    quant_dtype: torch.dtype | str | None,
    per_token_quant: bool,
    block_shape: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Quantize one weight tensor for MoE tests.

    Supports fp8e4m3 / int8 (blocked or per-tensor/per-token) and "nvfp4"
    (unblocked only).  Returns ``(w_q, w_scale, w_global_scale)``; the
    global scale is only produced for nvfp4.
    """
    assert (
        quant_dtype == torch.float8_e4m3fn
        or quant_dtype == torch.int8
        or quant_dtype == "nvfp4"
    ), "only fp8/int8/nvfp4 supported"

    w_gs = None

    if block_shape is not None:
        # Blocked quantization: one scale per (block_n, block_k) tile.
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = per_block_cast_to_fp8(w, block_shape)
        elif quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
    else:
        if quant_dtype == torch.int8:
            w, w_s = ops.scaled_int8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = ops.scaled_fp8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant
            )
        elif quant_dtype == "nvfp4":
            assert not per_token_quant
            # Global scale maps the tensor's amax onto the fp4*fp8 range.
            w_amax = torch.abs(w).max().to(torch.float32)
            w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
            w, w_s = ops.scaled_fp4_quant(w, w_gs)
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")

    return w, w_s, w_gs
|
||||
|
||||
|
||||
def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Create a random (e, rows, cols) expert weight, optionally quantized
    per expert.

    Returns ``(w_16, w, w_s, w_gs)``: the full-precision original, the
    (possibly quantized) weight, the scales, and the nvfp4 global scales.
    With quant_dtype=None, w is w_16 and the scales are None.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
    w_gs = None

    if quant_dtype is not None:
        w_l = [None] * e
        w_s_l = [None] * e
        w_gs_l = [None] * e
        # Quantize each expert independently.
        for idx in range(e):
            w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
                w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
            )

        w = torch.stack(w_l)
        w_s = torch.stack(w_s_l)
        if e > 0 and w_gs_l[0] is not None:
            w_gs = torch.stack(w_gs_l)
        # Per-tensor scales stack to (e, 1); normalize to (e, 1, 1).
        if w_s.ndim == 2:
            assert w_s.shape[-1] == 1
            w_s = w_s.view(-1, 1, 1)

        # Sanity-check the blocked scale layout: one scale per tile.
        if block_shape is not None:
            block_n, block_k = block_shape
            n_tiles = (rows + block_n - 1) // block_n
            k_tiles = (cols + block_k - 1) // block_k
            assert w_s.shape == (e, n_tiles, k_tiles)
    else:
        w = w_16
        w_s = None
        w_gs = None

    return w_16, w, w_s, w_gs
|
||||
|
||||
|
||||
def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
    block_shape: list[int] | None = None,
    per_out_ch_quant: bool = False,
    make_gate: bool = True,
) -> tuple[
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
    """Build the (w1, w2) expert weight pair for an MoE layer.

    w1 is (e, 2n, k) when *make_gate* (fused gate+up projection), else
    (e, n, k); w2 is (e, k, n).  Each element is the 4-tuple returned by
    ``make_test_weight``.
    """
    w1_rows = (2 if make_gate else 1) * n
    w1_tuple = make_test_weight(
        e,
        w1_rows,
        k,
        in_dtype,
        quant_dtype,
        block_shape,
        per_out_ch_quant,
    )
    w2_tuple = make_test_weight(
        e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant
    )
    return w1_tuple, w2_tuple
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(
|
||||
x: torch.Tensor, block_size: int = 128
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
pad_size = (block_size - (n % block_size)) % block_size
|
||||
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
|
||||
x_view = x.view(m, -1, block_size)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def make_test_quant_config(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | str | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
    """Create quantized (w1, w2) expert weights plus a matching
    FusedMoEQuantConfig for MoE kernel tests."""
    (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
        e,
        n,
        k,
        in_dtype,
        quant_dtype,
        per_out_ch_quant=per_act_token_quant,
        block_shape=block_shape,
        make_gate=make_gate,
    )

    # Hacky/trivial scales for nvfp4.
    a1_gscale: torch.Tensor | None = None
    a2_gscale: torch.Tensor | None = None
    if quant_dtype == "nvfp4":
        # Unit activation global scales — only the weight scales matter here.
        a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
        a1_scale = a1_gscale
        a2_scale = a2_gscale
    else:
        a1_scale = None
        a2_scale = None

    return (
        w1,
        w2,
        FusedMoEQuantConfig.make(
            quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
            w1_scale=w1_s,
            w2_scale=w2_s,
            a1_gscale=a1_gscale,
            a2_gscale=a2_gscale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            # TODO: make sure this is handled properly
            g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
            g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
        ),
    )
|
||||
|
||||
|
||||
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    renormalize: bool = False,
    quant_config: FusedMoEQuantConfig | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor:
    """Convenience wrapper: route with ``fused_topk`` on raw scores, then
    run ``fused_experts``.  Mirrors the production call pattern for tests."""
    # fused_topk expects float scores; routing weights/ids feed the experts.
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, score.float(), topk, renormalize
    )
    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        quant_config=quant_config,
    )
|
||||
|
||||
|
||||
# CustomOp?
|
||||
class BaselineMM(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
):
|
||||
super().__init__()
|
||||
self.b = b.to(dtype=torch.float32)
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
|
||||
|
||||
|
||||
class TestMLP(torch.nn.Module):
    """Reference SwiGLU MLP built from BaselineMM float32 matmuls.

    gate_up -> SiluAndMul -> down, matching the shared-experts layout; w1
    is the fused gate+up projection (2x intermediate width).
    """

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # BaselineMM returns (output, None); drop the bias slot.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x
|
||||
|
||||
|
||||
def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """Build a random unquantized reference shared-experts MLP (TestMLP)
    with hidden size K and intermediate size N, on CUDA."""
    # Fused gate+up weight is (K, 2N); down projection maps N back to K.
    w1 = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15
    w2 = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15
    return TestMLP(w1, w2, out_dtype=in_dtype)
|
||||
|
||||
|
||||
class RealMLP(torch.nn.Module):
    """SwiGLU MLP built from the real vLLM parallel-linear layers, with
    pre-made (possibly quantized) weights injected directly as parameters.

    Only ``hidden_act == "silu"`` is supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: torch.Tensor | None = None,
        w2_s: torch.Tensor | None = None,
    ) -> None:
        # Local import to avoid pulling the heavy linear-layer stack in at
        # module import time.
        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear,
            RowParallelLinear,
        )

        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        # Overwrite the layer's own parameters with the provided weights
        # (requires_grad=False: inference-only test module).
        self.gate_up_proj.register_parameter(
            "weight", torch.nn.Parameter(w1, requires_grad=False)
        )
        self.gate_up_proj.register_parameter(
            "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)
        )
        self.gate_up_proj.register_parameter(
            "input_scale", None
        )  # torch.nn.Parameter(None, requires_grad=False))
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.down_proj.register_parameter(
            "weight", torch.nn.Parameter(w2, requires_grad=False)
        )
        self.down_proj.register_parameter(
            "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)
        )
        self.down_proj.register_parameter(
            "input_scale", None
        )  # torch.nn.Parameter(None, requires_grad=False))
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # Parallel linear layers return (output, bias); bias is unused here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
|
||||
|
||||
|
||||
def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: torch.dtype | str | None = None,
) -> torch.nn.Module:
    """Build a RealMLP shared-experts module with random weights, either
    unquantized or fp8-quantized (other quant_dtypes fall back to the
    unquantized path)."""
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config

    # Single "expert" (e=1); only the (possibly quantized) weights are used.
    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )
    old_dtype = torch.get_default_dtype()
    try:
        # Layer construction reads the default dtype; restore it afterwards.
        torch.set_default_dtype(in_dtype)
        if quant_dtype == torch.float8_e4m3fn:
            # fp8 linear layers store weights transposed relative to the
            # test-weight layout.
            w1 = w1[0].transpose(0, 1)
            w2 = w2[0].transpose(0, 1)
            w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
            w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
            quant_config = Fp8Config(True)
        else:
            w1 = w1[0]
            w2 = w2[0]
            w1_s = None
            w2_s = None
            quant_config = None

        return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
    finally:
        torch.set_default_dtype(old_dtype)
|
||||
Reference in New Issue
Block a user