From ce688181e68c103c2b10d20ef51f5f7c9e42cd93 Mon Sep 17 00:00:00 2001 From: wangjing Date: Mon, 25 Aug 2025 17:41:34 +0800 Subject: [PATCH] [gpt-oss] Add gpt-oss mxfp4 support --- .gitignore | 212 ++- README.md | 2 +- vllm/attention/layer.py | 99 +- vllm/attention/selector.py | 84 +- vllm/attention/utils/kv_sharing_utils.py | 33 + vllm/envs.py | 32 + .../layers/fused_moe/__init__.py | 26 +- .../model_executor/layers/fused_moe/config.py | 490 +++++++ .../layers/fused_moe/fused_batched_moe.py | 4 +- .../layers/fused_moe/fused_moe.py | 8 +- .../fused_moe/gpt_oss_triton_kernels_moe.py | 248 ++++ vllm/model_executor/layers/fused_moe/layer.py | 1143 ++++++++++------- .../layers/fused_moe/modular_kernel.py | 507 ++++++-- .../fused_moe/topk_weight_and_reduce.py | 146 +++ vllm/model_executor/layers/fused_moe/utils.py | 189 ++- .../layers/quantization/__init__.py | 3 + .../layers/quantization/mxfp4.py | 581 +++++++++ .../quark/schemes/quark_w4a4_mxfp4.py | 78 +- .../_downcast_to_mxfp.py | 158 +++ .../_upcast_from_mxfp.py | 136 ++ .../triton_kernels_numerics_details/mxfp.py | 303 +++++ .../layers/quantization/utils/marlin_utils.py | 15 +- .../quantization/utils/marlin_utils_fp4.py | 167 ++- .../layers/quantization/utils/mxfp4_utils.py | 142 +- .../layers/quantization/utils/quant_utils.py | 90 +- vllm/model_executor/models/gpt_oss.py | 730 +++++------ vllm/platforms/cuda.py | 120 +- vllm/utils.py | 34 + vllm/v1/attention/backends/flash_attn.py | 59 +- vllm/v1/attention/backends/flashinfer.py | 25 +- vllm/v1/attention/backends/flex_attention.py | 15 +- vllm/v1/attention/backends/pallas.py | 12 +- vllm/v1/attention/backends/triton_attn.py | 136 +- 33 files changed, 4835 insertions(+), 1192 deletions(-) create mode 100644 vllm/attention/utils/kv_sharing_utils.py create mode 100644 vllm/model_executor/layers/fused_moe/config.py create mode 100644 vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py create mode 100644 
vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py create mode 100644 vllm/model_executor/layers/quantization/mxfp4.py create mode 100644 vllm/model_executor/layers/quantization/triton_kernels_numerics_details/_downcast_to_mxfp.py create mode 100644 vllm/model_executor/layers/quantization/triton_kernels_numerics_details/_upcast_from_mxfp.py create mode 100644 vllm/model_executor/layers/quantization/triton_kernels_numerics_details/mxfp.py diff --git a/.gitignore b/.gitignore index 372c13e..465935d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,212 @@ -__pycache__/ +# version file generated by setuptools-scm +/vllm/_version.py +# vllm-flash-attn built from source +vllm/vllm_flash_attn/* + +# triton jit +.triton + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +cmake-build-*/ +CMakeUserPresets.json +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +/.deps/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# generated files +**/generated/** + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site +docs/argparse +docs/examples/* +!docs/examples/README.md + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +# VSCode +.vscode/ + +# DS Store +.DS_Store + +# Results +*.csv + +# Python pickle files +*.pkl + +# Sphinx documentation +_build/ + +# vim swap files +*.swo +*.swp + +# hip files generated by PyTorch +*.hip +*_hip* +hip_compat.h + +# Benchmark dataset +benchmarks/**/*.json + +# Linting +actionlint +shellcheck*/ + +# Ignore moe/marlin_moe gen code +csrc/moe/marlin_moe_wna16/kernel_* + +# Ignore ep_kernels_workspace folder +ep_kernels_workspace/ \ No newline at end of file diff --git a/README.md b/README.md index b08f61c..5d911c9 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ # metax-c500-vllm -1. 支持 `gpt-oss-BF16`:将 `vllm` 目录覆盖到镜像中的 `/opt/conda/lib/python3.10/site-packages/vllm` +1. 支持 `gpt-oss`:将 `vllm` 目录覆盖到镜像中的 `/opt/conda/lib/python3.10/site-packages/vllm`。运行`gpt-oss`时需指定`VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1` 2. 
将 `code_generator.py` 覆盖到镜像中的 `/opt/conda/lib/python3.10/site-packages/triton/compiler/code_generator.py` diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a5fbd1a..55a2d27 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch import torch.nn as nn @@ -9,19 +9,49 @@ import torch.nn.functional as F import vllm.envs as envs from vllm.attention import AttentionType +from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group) from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op -from vllm.v1.attention.backends.utils import validate_kv_sharing_target + +logger = init_logger(__name__) +USE_XFORMERS_OPS = None + + +def check_xformers_availability(): + global USE_XFORMERS_OPS + if USE_XFORMERS_OPS is not None: + return USE_XFORMERS_OPS + + if current_platform.is_cuda() and current_platform.has_device_capability( + 100): + # Xformers FA is not compatible with B200 + USE_XFORMERS_OPS = False + else: + try: + from importlib.util import find_spec + + find_spec("xformers.ops") + USE_XFORMERS_OPS = True + 
except ImportError: + USE_XFORMERS_OPS = False + + # the warning only needs to be shown once + if not USE_XFORMERS_OPS: + logger.warning("Xformers is not available, falling back.") + + return USE_XFORMERS_OPS class Attention(nn.Module): @@ -45,13 +75,13 @@ class Attention(nn.Module): alibi_slopes: Optional[List[float]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, use_mla: bool = False, prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + attn_backend: Optional[type[AttentionBackend]] = None, **extra_impl_args, ) -> None: """ @@ -80,6 +110,9 @@ class Attention(nn.Module): calculate_kv_scales = False if num_kv_heads is None: num_kv_heads = num_heads + assert num_heads % num_kv_heads == 0, \ + f"num_heads ({num_heads}) is not " \ + f"divisible by num_kv_heads ({num_kv_heads})" # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with @@ -105,6 +138,7 @@ class Attention(nn.Module): self.head_size = head_size self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window + self.has_sink = extra_impl_args.get("sinks") is not None quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None @@ -126,19 +160,23 @@ class Attention(nn.Module): # During model initialization, the default dtype is set as the model # weight and activation dtype. 
dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, - dtype, - kv_cache_dtype, - block_size, - is_attention_free, - blocksparse_params is not None, - use_mla=use_mla) - impl_cls = attn_backend.get_impl_cls() + if attn_backend is None: + self.attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + use_mla=use_mla, + has_sink=self.has_sink) + else: + self.attn_backend = attn_backend + + impl_cls = self.attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **extra_impl_args) - self.backend = backend_name_to_enum(attn_backend.get_name()) + self.backend = backend_name_to_enum(self.attn_backend.get_name()) self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how @@ -148,7 +186,7 @@ class Attention(nn.Module): self.use_direct_call = not current_platform.is_cuda_alike( ) and not current_platform.is_cpu() - self.use_output = attn_backend.accept_output_buffer + self.use_output = self.attn_backend.accept_output_buffer compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") @@ -206,7 +244,7 @@ class Attention(nn.Module): if self.use_output: output_shape = (output_shape if output_shape is not None else query.shape) - output = torch.empty(output_shape, + output = torch.zeros(output_shape, dtype=query.dtype, device=query.device) hidden_size = output_shape[-1] @@ -274,6 +312,9 @@ class Attention(nn.Module): if hasattr(self.impl, "process_weights_after_loading"): self.impl.process_weights_after_loading(act_dtype) + def get_attn_backend(self) -> type[AttentionBackend]: + return self.attn_backend + class MultiHeadAttention(nn.Module): """Multi-headed attention without any 
cache, used for ViT.""" @@ -291,7 +332,9 @@ class MultiHeadAttention(nn.Module): self.scale = scale self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - assert self.num_heads % self.num_kv_heads == 0 + assert self.num_heads % self.num_kv_heads == 0, \ + f"num_heads ({self.num_heads}) is not " \ + f"divisible by num_kv_heads ({self.num_kv_heads})" self.num_queries_per_kv = self.num_heads // self.num_kv_heads dtype = torch.get_default_dtype() @@ -301,12 +344,21 @@ class MultiHeadAttention(nn.Module): block_size=16, is_attention_free=False) backend = backend_name_to_enum(attn_backend.get_name()) - if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: - backend = _Backend.XFORMERS + if current_platform.is_rocm(): + # currently, only torch_sdpa is supported on rocm + self.attn_backend = _Backend.TORCH_SDPA + else: + if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1, + _Backend.FLEX_ATTENTION): + backend = _Backend.XFORMERS - self.attn_backend = backend if backend in { - _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 - } else _Backend.TORCH_SDPA + self.attn_backend = backend if backend in { + _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 + } else _Backend.TORCH_SDPA + + if (self.attn_backend == _Backend.XFORMERS + and not check_xformers_availability()): + self.attn_backend = _Backend.TORCH_SDPA def forward( self, @@ -430,6 +482,7 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, + output_scale: Optional[torch.Tensor] = None, ) -> None: wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() @@ -444,7 +497,8 @@ def unified_attention_with_output( value, kv_cache, attn_metadata, - output=output) + output=output, + output_scale=output_scale) maybe_save_kv_layer_to_connector(layer_name, kv_cache) @@ -455,6 +509,7 @@ def unified_attention_with_output_fake( value: torch.Tensor, output: torch.Tensor, 
layer_name: str, + output_scale: Optional[torch.Tensor] = None, ) -> None: return diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index cb577fa..da0e6f4 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -3,8 +3,9 @@ import os from contextlib import contextmanager +from dataclasses import dataclass from functools import cache -from typing import Generator, Optional, Type +from typing import Generator, Optional, Type, Union import torch @@ -79,15 +80,72 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: return forced_attn_backend +@dataclass(frozen=True) +class _IsSupported: + can_import: bool + head_size: bool + dtype: bool + + def __bool__(self) -> bool: + return self.can_import and self.head_size and self.dtype + + +def is_attn_backend_supported( + attn_backend: Union[str, type[AttentionBackend]], + head_size: int, + dtype: torch.dtype, + *, + allow_import_error: bool = True, +) -> _IsSupported: + if isinstance(attn_backend, str): + try: + attn_backend = resolve_obj_by_qualname(attn_backend) + except ImportError: + if not allow_import_error: + raise + + return _IsSupported(can_import=False, head_size=False, dtype=False) + + assert isinstance(attn_backend, type) + + # TODO: Update the interface once V0 is removed + if get_supported_head_sizes := getattr(attn_backend, + "get_supported_head_sizes", None): + is_head_size_supported = head_size in get_supported_head_sizes() + elif validate_head_size := getattr(attn_backend, "validate_head_size", + None): + try: + validate_head_size(head_size) + is_head_size_supported = True + except Exception: + is_head_size_supported = False + else: + raise NotImplementedError(f"{attn_backend.__name__} does not support " + "head size validation") + + if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", + None): + is_dtype_supported = dtype in get_supported_dtypes() + else: + raise NotImplementedError(f"{attn_backend.__name__} does not support " + "dtype 
validation") + + return _IsSupported( + can_import=True, + head_size=is_head_size_supported, + dtype=is_dtype_supported, + ) + + def get_attn_backend( head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, - is_attention_free: bool, - is_blocksparse: bool = False, + is_attention_free: bool = False, use_mla: bool = False, -) -> Type[AttentionBackend]: + has_sink: bool = False, +) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong # value to be returned from the cache if the value changes between calls. @@ -99,9 +157,9 @@ def get_attn_backend( kv_cache_dtype=kv_cache_dtype, block_size=block_size, is_attention_free=is_attention_free, - is_blocksparse=is_blocksparse, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, + has_sink=has_sink, ) @@ -112,16 +170,10 @@ def _cached_get_attn_backend( kv_cache_dtype: Optional[str], block_size: int, is_attention_free: bool, - is_blocksparse: bool = False, use_v1: bool = False, use_mla: bool = False, -) -> Type[AttentionBackend]: - if is_blocksparse: - logger.info("Using BlocksparseFlashAttention backend.") - from vllm.attention.backends.blocksparse_attn import ( - BlocksparseFlashAttentionBackend) - return BlocksparseFlashAttentionBackend - + has_sink: bool = False, +) -> type[AttentionBackend]: # If there are no attention layers (e.g. we are running Mamba), # use the placeholder NO_ATTENTION if is_attention_free: @@ -144,11 +196,15 @@ def _cached_get_attn_backend( backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND if backend_by_env_var is not None: selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + raise ValueError( + f"Invalid attention backend: '{backend_by_env_var}'. 
" + f"Valid backends are: {list(_Backend.__members__.keys())}") # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, - use_mla) + use_mla, has_sink) if not attention_cls: raise ValueError( f"Invalid attention backend for {current_platform.device_name}") diff --git a/vllm/attention/utils/kv_sharing_utils.py b/vllm/attention/utils/kv_sharing_utils.py new file mode 100644 index 0000000..b4ae8bd --- /dev/null +++ b/vllm/attention/utils/kv_sharing_utils.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +def validate_kv_sharing_target(current_layer_name, target_layer_name, + static_forward_context): + error_msg = (f"Specified KV sharing target layer for {current_layer_name} " + f"is not valid: target layer {target_layer_name} ") + + if current_layer_name == target_layer_name: + raise ValueError(error_msg + + "cannot be the same as the current layer.") + + if target_layer_name not in static_forward_context: + from vllm.model_executor.models.utils import extract_layer_index + + # If target layer name is not in the static fwd context, it means either + # a) the target layer does not come BEFORE the current layer, or + # b) the target layer is not an Attention layer that exists in the model + current_layer_idx = extract_layer_index(current_layer_name) + target_layer_idx = extract_layer_index(target_layer_name) + if current_layer_idx <= target_layer_idx: + raise ValueError(error_msg + "must come before the current layer.") + else: + raise ValueError(error_msg + + "is not a valid Attention layer in the model.") + + # Currently KV sharing is only supported between layers of the same type + target_layer_attn_type = static_forward_context[ + target_layer_name].attn_type + expected = static_forward_context[current_layer_name].attn_type + if target_layer_attn_type != expected: + raise ValueError( 
+ error_msg + + f"must be the same type as the current layer ({expected}).") diff --git a/vllm/envs.py b/vllm/envs.py index 1d35123..753ecd9 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -112,8 +112,10 @@ if TYPE_CHECKING: VLLM_DP_SIZE: int = 1 VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 + VLLM_MOE_DP_CHUNK_SIZE: int = 256 VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False + VLLM_MXFP4_USE_MARLIN: Optional[bool] = None VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_USE_DEEP_GEMM: bool = False @@ -128,6 +130,8 @@ if TYPE_CHECKING: VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 MACA_VLLM_USE_TN_2_NN: bool = True + VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False + VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False def get_default_cache_root(): return os.getenv( @@ -149,6 +153,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: return int(value) +def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: + if value is None: + return None + return bool(int(value)) + + def get_vllm_port() -> Optional[int]: """Get the port from VLLM_PORT environment variable. @@ -769,6 +779,14 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_DP_MASTER_IP": lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"), + # In the context of executing MoE models with Data-Parallel, Expert-Parallel + # and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE + # dictates the quantum of tokens that can be dispatched from a DP + # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE + # units. 
+ "VLLM_MOE_DP_CHUNK_SIZE": + lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), + # Port of the master node in the data parallel setting "VLLM_DP_MASTER_PORT": lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), @@ -794,6 +812,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MARLIN_USE_ATOMIC_ADD": lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", + # Whether to use marlin kernel in mxfp4 quantization method + "VLLM_MXFP4_USE_MARLIN": + lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)), + # Whether to turn on the outlines cache for V0 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. @@ -810,6 +832,16 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), + # If set to 1, use the FlashInfer + # MXFP8 (activation) x MXFP4 (weight) MoE backend. + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))), + + # If set to 1, use the FlashInfer + # BF16 (activation) x MXFP4 (weight) MoE backend. + "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))), + # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. # It can be changed with this variable if needed for some reason. 
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 2bdc96e..3d40879 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -4,8 +4,12 @@ from contextlib import contextmanager from typing import Any, Optional +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize) from vllm.triton_utils import HAS_TRITON _config: Optional[dict[str, Any]] = None @@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]: __all__ = [ "FusedMoE", + "FusedMoEConfig", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", + "FusedMoEPermuteExpertsUnpermute", + "FusedMoEActivationFormat", + "FusedMoEPrepareAndFinalize", "override_config", "get_config", ] @@ -36,11 +44,21 @@ if HAS_TRITON: # import to register the custom ops import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) + from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4, cutlass_moe_fp8) + CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8) + from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + DeepGemmExperts) + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.fused_moe import ( TritonExperts, fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) + from 
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) __all__ += [ "fused_moe", @@ -50,5 +68,11 @@ if HAS_TRITON: "grouped_topk", "cutlass_moe_fp8", "cutlass_moe_fp4", + "CutlassExpertsFp8", "TritonExperts", + "BatchedTritonExperts", + "DeepGemmExperts", + "BatchedDeepGemmExperts", + "TritonOrDeepGemmExperts", + "BatchedTritonOrDeepGemmExperts", ] diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py new file mode 100644 index 0000000..c19a7fe --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -0,0 +1,490 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy, + QuantizationType) + +import vllm.envs as envs +from vllm.config import ParallelConfig +from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.utils import cdiv +# from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe + +logger = init_logger(__name__) + + +def _get_quant_config_quantization_args( + quant_config: Optional[QuantizationConfig], + prop_name: str, +) -> Optional[QuantizationArgs]: + if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') + and "Linear" in quant_config.target_scheme_map and + "input_activations" in quant_config.target_scheme_map["Linear"]): + return quant_config.target_scheme_map["Linear"].get(prop_name) + else: + return None + + +def get_quant_config_input_quant( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + return _get_quant_config_quantization_args(quant_config, + "input_activations") + + +def 
get_quant_config_weight_quant( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + return _get_quant_config_quantization_args(quant_config, "weights") + + +# TODO (bnell): use scalar_type instead of bools? +def get_config_quant_dtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + use_mxfp4_w4a4: bool, +) -> Union[None, torch.dtype, str]: + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + elif use_mxfp4_w4a4: + return "mxfp4" + return None + + +@dataclass +class FusedMoEQuantConfig: + # The post quantization activation type. + quant_dtype: Optional[torch.dtype] = None + per_act_token_quant: bool = False + per_out_ch_quant: bool = False + block_shape: Optional[list[int]] = None + + # TODO: add col major flag? + # add detailed quant info for input, intermediates, weights, etc? + + def __post_init__(self): + assert (not self.per_act_token_quant + or self.block_shape is None), "illegal quantization" + + @property + def is_quantized(self) -> bool: + return self.quant_dtype is not None + + @property + def is_per_act_token(self) -> bool: + return self.per_act_token_quant + + @property + def is_block_quantized(self) -> bool: + return self.block_shape is not None + + @property + def is_per_tensor(self) -> bool: + return not self.per_act_token_quant and self.block_shape is None + + def scale_shape( + self, + max_tokens: int, + hidden_dim: int, + ) -> Optional[tuple[int, int]]: + if self.is_quantized: + if self.is_block_quantized: + assert self.block_shape is not None + _, block_k = self.block_shape + k_tiles = cdiv(hidden_dim, block_k) + return (max_tokens, k_tiles) + elif self.is_per_act_token: + return (max_tokens, 1) + else: + return (1, 1) + else: + return None + + def batched_scale_shape( + self, + num_experts: int, + max_tokens: int, + hidden_dim: int, + ) -> Optional[tuple[int, int, int]]: + if self.is_quantized: + scale_shape = 
self.scale_shape(max_tokens, hidden_dim) + assert scale_shape is not None + return (num_experts, *scale_shape) + else: + return None + + @staticmethod + def make( + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, + block_shape: Optional[list[int]] = None, + ) -> "FusedMoEQuantConfig": + assert sum([ + int(flag) for flag in [ + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + ] + ]) <= 1, "Quantization flags are mutually exclusive." + + quant_dtype = get_config_quant_dtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, + ) + return FusedMoEQuantConfig( + quant_dtype, + per_act_token_quant, + per_out_ch_quant, + block_shape, + ) + + +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + + use_ep: bool # whether to use EP or not + + @property + def use_all2all_kernels(self): + return self.dp_size > 1 and self.use_ep + + @property + def use_pplx_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "pplx") + + @property + def use_deepep_ht_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") + + @property + def use_deepep_ll_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + + @property + def use_flashinfer_cutlass_kernels(self): + # return (envs.VLLM_USE_FLASHINFER_MOE_FP4 + # and has_flashinfer_cutlass_fused_moe() + # and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput") + return False + + @staticmethod + def make(tp_size_: int, dp_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE 
parallel configuration. Based on the input `tp_size_`, + `dp_size_` and vllm's parallel config, determine what + level's of parallelism to use in the fused moe layer. + + Args: + tp_size_ (int): `tp_size` passed into the FusedMoE constructor. + dp_size_ (int): `dp_size` passed into the FusedMoE constructor. + vllm_parallel_config (ParallelConfig): vLLM's parallel config + object which contains the `enable_expert_parallel` flag. + + Examples: + When there is no parallelism requested, + i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes + unaltered and the ranks set to 0. + + Expert Parallelism is considered only when either `dp_size_` or + `tp_size_` is non trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different + devices: + + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // + legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different + devices: + + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 2 decvices. + + When TP = 2, DP = 2 and EP = False, the configuration on different + devices: + + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 4 devices. + + When, TP = 2, DP = 1 and EP = True, the configuration on different + devices: + + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. 
+ + When, TP = 1, DP = 2 and EP = True, the configuration on different + devices: + + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split + between the 2 devices. + + When TP = 2, DP = 2 and EP = True, the configuration on different + devices: + + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split + between the 4 devices. + """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. + tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = (dp_size_ * tp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel) + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + use_ep=False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + # In EP, each device owns a set of experts fully. There is no tensor + # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. 
+ ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + use_ep=True) + + +# Adapted from pplx-kernels tests/all_to_all_utils.py +@dataclass +class FusedMoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + moe_parallel_config: FusedMoEParallelConfig + + # The activation type. + in_dtype: torch.dtype + + quant_config: Optional[FusedMoEQuantConfig] = None + + max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE + + has_bias: bool = False + + def __post_init__(self): + if self.dp_size > 1: + logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d", + self.max_num_tokens) + + assert self.max_num_tokens > 0 + + @property + def quant_dtype(self) -> Optional[torch.dtype]: + if self.quant_config is not None: + return self.quant_config.quant_dtype + else: + return None + + @property + def block_shape(self) -> Optional[list[int]]: + if self.quant_config is not None: + return self.quant_config.block_shape + else: + return None + + @property + def per_act_token_quant(self) -> bool: + if self.quant_config is not None: + return self.quant_config.per_act_token_quant + else: + return False + + @property + def per_out_ch_quant(self) -> bool: + if self.quant_config is not None: + return self.quant_config.per_out_ch_quant + else: + return False + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + 
return self.moe_parallel_config.use_pplx_kernels + + @property + def use_deepep_ht_kernels(self): + return self.moe_parallel_config.use_deepep_ht_kernels + + @property + def use_deepep_ll_kernels(self): + return self.moe_parallel_config.use_deepep_ll_kernels + + @property + def use_flashinfer_cutlass_kernels(self): + return self.moe_parallel_config.use_flashinfer_cutlass_kernels + + @staticmethod + def make( + num_experts: int, + experts_per_token: int, + hidden_dim: int, + num_local_experts: int, + moe_parallel_config: FusedMoEParallelConfig, + in_dtype: torch.dtype, + max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE, + quant_config: Optional[Union[FusedMoEQuantConfig, + QuantizationConfig]] = None, + has_bias: bool = False, + ) -> "FusedMoEConfig": + + _quant_config: Optional[FusedMoEQuantConfig] = None + + if quant_config is not None and isinstance(quant_config, + QuantizationConfig): + if hasattr(quant_config, 'weight_block_size'): + block_shape = quant_config.weight_block_size + else: + block_shape = None + per_act_token_quant = False + per_out_ch_quant = False + quant_dtype: Optional[torch.dtype] = None + + input_quant = get_quant_config_input_quant(quant_config) + weight_quant = get_quant_config_weight_quant(quant_config) + + if input_quant is not None: + per_act_token_quant = (input_quant.strategy + == QuantizationStrategy.TOKEN + if input_quant is not None else False) + + if input_quant.num_bits == 8: + if input_quant.type == QuantizationType.FLOAT: + quant_dtype = torch.float8_e4m3fn + elif input_quant.type == QuantizationType.INT: + quant_dtype = torch.int8 + + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + if quant_dtype is None and isinstance(quant_config, Fp8Config): + quant_dtype = torch.float8_e4m3fn + + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4Config) + if quant_dtype is None and isinstance(quant_config, + ModelOptNvFp4Config): + quant_dtype = torch.uint8 + + if weight_quant is not None: + 
per_out_ch_quant = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL) + + if quant_dtype is not None: + _quant_config = FusedMoEQuantConfig( + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + ) + else: + _quant_config = FusedMoEQuantConfig() + if moe_parallel_config.dp_size > 1: + logger.warning_once("MoE DP setup unable to determine " + "quantization scheme or unsupported " + "quantization type. This model will " + "not run with DP enabled.") + else: + _quant_config = quant_config + + return FusedMoEConfig( + num_experts=num_experts, + experts_per_token=experts_per_token, + hidden_dim=hidden_dim, + num_local_experts=num_local_experts, + moe_parallel_config=moe_parallel_config, + in_dtype=in_dtype, + quant_config=_quant_config, + max_num_tokens=max_num_tokens, + has_bias=has_bias, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 68a3485..72c62ca 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -752,8 +752,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, - qtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None, - per_channel_quant=self.per_channel_quant, + quant_dtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None, + per_act_token_quant=self.per_channel_quant, block_shape=self.block_shape) qintermediate_cache2 = qintermediate_cache2.view( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a16b66a..ba8f209 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1503,8 +1503,8 @@ def fused_experts_impl( qcurr_hidden_states, a1q_scale = 
moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, - qtype=qtype, - per_channel_quant=per_channel_quant, + quant_dtype=qtype, + per_act_token_quant=per_channel_quant, block_shape=block_shape) invoke_fused_moe_kernel(qcurr_hidden_states, @@ -1562,8 +1562,8 @@ def fused_experts_impl( qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, - qtype=qtype, - per_channel_quant=per_channel_quant, + quant_dtype=qtype, + per_act_token_quant=per_channel_quant, block_shape=block_shape) invoke_fused_moe_kernel(qintermediate_cache2, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py new file mode 100644 index 0000000..6b5284d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Any, Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) +from vllm.model_executor.layers.fused_moe.utils import extract_required_args +from vllm.utils import has_triton_kernels + +logger = init_logger(__name__) + +if has_triton_kernels(): + try: + import triton_kernels.swiglu + from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation, + matmul_ogs) + from triton_kernels.routing import routing + except ModuleNotFoundError: + logger.error( + "Failed to import Triton kernels. 
Please make sure your triton " + "version is compatible.") + +if TYPE_CHECKING: + from triton_kernels.matmul_ogs import PrecisionConfig + + +def triton_kernel_moe_forward( + hidden_states: torch.Tensor, + w1, # Tensor or triton_kernels.Tensor + w2, # Tensor or triton_kernels.Tensor + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + w1_precision: Optional["PrecisionConfig"] = None, + w2_precision: Optional["PrecisionConfig"] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + + routing_data, gather_idx, scatter_idx = routing(gating_output, + topk, + sm_first=not renormalize) + + return triton_kernel_fused_experts( + None, + hidden_states, + w1, + w2, + routing_data, + gather_idx, + scatter_idx, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=use_fp8_w8a8, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_precision=w1_precision, + w2_precision=w2_precision, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape) + + +# This is a triton implementation of the fused_experts function +def triton_kernel_fused_experts( + output_tensor: torch.Tensor, + hidden_states: torch.Tensor, + w1, # Tensor or triton_kernels.Tensor + w2, # Tensor or triton_kernels.Tensor + routing_data, # RoutingData + gather_indx, # GatherIndx + scatter_indx, # ScatterIndx + activation: str = "silu", 
+ swiglu_alpha: float = 1.702, + swiglu_limit: float = 7.0, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + w1_precision: Optional["PrecisionConfig"] = None, + w2_precision: Optional["PrecisionConfig"] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + + # type check, uint8 means mxfp4 + assert hidden_states.dtype == torch.bfloat16 + assert w1_bias is None or w1_bias.dtype == torch.float32 + assert w2_bias is None or w2_bias.dtype == torch.float32 + + # Shape check, only check non-mxfp4 + assert hidden_states.shape[-1] == w1.shape[-2] + assert w2.shape[-1] == w1.shape[1] + + E, _, N = w1.shape + + if global_num_experts == -1: + global_num_experts = E + + act = FusedActivation( + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), + (swiglu_alpha, swiglu_limit), 2) + gammas = routing_data.gate_scal if routing_data else None + + intermediate_cache1 = matmul_ogs( + hidden_states, + w1, + w1_bias, + routing_data, + gather_indx=gather_indx, + precision_config=w1_precision, + gammas=gammas if apply_router_weight_on_input else None, + fused_activation=act) + + intermediate_cache3 = matmul_ogs( + intermediate_cache1, + w2, + w2_bias, + routing_data, + scatter_indx=scatter_indx, + precision_config=w2_precision, + gammas=None if apply_router_weight_on_input else gammas, + y=output_tensor, + ) + return intermediate_cache3 + + +class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + quant_config, + max_num_tokens: int, + num_dispatchers: int, + w1_precision: "PrecisionConfig", + w2_precision: 
"PrecisionConfig", + ): + super().__init__(quant_config) + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + self.w1_precision = w1_precision + self.w2_precision = w2_precision + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + + def workspace_shapes( + self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, + topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata] + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # workspace are allocated inside the kernel + assert a.dim() == 2 + num_dp = self.num_dispatchers + num_experts = local_num_experts + max_num_tokens = self.max_num_tokens + workspace2 = (0, 0, 0) + output = (num_experts, max_num_tokens * num_dp, N) + return (output, workspace2, output, a.dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], + ): + w1_bias, w2_bias = 
(extract_required_args(extra_expert_args, + ["w1_bias", "w2_bias"])) + + return triton_kernel_fused_experts( + output, + hidden_states, + w1, + w2, + None, + None, + None, + activation=activation, + apply_router_weight_on_input=False, + use_fp8_w8a8=False, + per_channel_quant=False, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_precision=self.w1_precision, + w2_precision=self.w2_precision, + a1_scale=a1q_scale, + a2_scale=a2_scale) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 337e0c0..a042871 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,57 +1,57 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import importlib from abc import abstractmethod -from dataclasses import dataclass +from collections.abc import Iterable from enum import Enum -from typing import Callable, Optional, Union +from typing import Callable, Literal, Optional, overload import torch import torch.nn.functional as F -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.config import ParallelConfig, get_current_vllm_config +from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_ep_group, - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +# from vllm.distributed.eplb.eplb_state import EplbState from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp +# yapf: disable +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig) +# 
yapf: enable +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEActivationFormat, FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) +# from vllm.model_executor.layers.fused_moe.routing_simulator import ( +# RoutingSimulator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import direct_register_custom_op - -has_pplx = importlib.util.find_spec("pplx_kernels") is not None -has_deepep = importlib.util.find_spec("deep_ep") is not None - -if has_deepep: - try: - import deep_ep - except ImportError: - has_deepep = False +from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, + round_up) +# from vllm.utils.flashinfer import has_flashinfer if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts - from .modular_kernel import (FusedMoEModularKernel, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize) - if has_pplx: - from .pplx_prepare_finalize import PplxPrepareAndFinalize - if has_deepep: + if has_pplx(): + from .pplx_prepare_finalize import (PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes) + if has_deep_ep(): from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize - from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize + from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, + DeepEPLLPrepareAndFinalize) + # if has_flashinfer(): + # from .flashinfer_cutlass_prepare_finalize import ( + # FlashInferCutlassMoEPrepareAndFinalize) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -59,215 +59,17 
@@ else: if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_grouped_topk as grouped_topk) +elif current_platform.is_cpu(): + pass else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): from .moe_pallas import fused_moe as fused_moe_pallas else: fused_moe_pallas = None # type: ignore + logger = init_logger(__name__) -# Note: this limit is somewhat arbitrary and might be changed later. -# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim. -MOE_DP_CHUNK_SIZE = 256 - - -@dataclass -class FusedMoEParallelConfig: - tp_size: int - dp_size: int - ep_size: int - tp_rank: int - dp_rank: int - ep_rank: int - - use_ep: bool # whether to use EP or not - - @property - def use_all2all_kernels(self): - return self.dp_size > 1 and self.use_ep - - @property - def use_pplx_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "pplx") - - @property - def use_deepep_ht_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") - - @property - def use_deepep_ll_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") - - @staticmethod - def make(tp_size_: int, dp_size_: int, - vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": - """ - Determine MoE parallel configuration. Based on the input tp_size_, - dp_size_, ep_size_ and vllm's parallel config, determine what - level's of parallelism to use in the fused moe layer. - - Args: - tp_size_ (int): tp_size passed into the FusedMoE constructor. - dp_size_ (int): dp_size passed into the FusedMoE constructor. - ep_size_ (int): ep_size passed into the FusedMoE constructor. - vllm_parallel_config (ParallelConfig): vllm's parallel config - object. - - Examples: - When there is no parallelism requested, i.e. 
tp_size_ = dp_size_ = 1, - we simply return the sizes unaltered and the ranks set to 0. - - Expert Parallelism is considered only when either dp_size_ or tp_size_ - is non trivial. - - When TP = 2, DP = 1 and EP = False, the configuration on different - devices, - - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // - legend : {size, rank} - - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} - - Comment : Tensors are sharded across 2 devices. - - When TP = 1, DP = 2 and EP = False, the configuration on different - devices, - - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} - - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} - - Comment: There are 2 engine instances and the tensors are sharded - across 2 decvices. - - When TP = 2, DP = 2 and EP = False, the configuration on different - devices, - - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} - - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} - - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} - - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} - - Comment: There are 2 engine instances and the tensors are sharded - across 4 devices. - - When, TP = 2, DP = 1 and EP = True, the configuration on different - devices, - - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} - - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} - - Comment: The experts are split between the 2 devices. - - When, TP = 1, DP = 2 and EP = True, the configuration on different - devices, - - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} - - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} - - Comment: There are 2 engine instances and the experts are split - between the 2 devices. - - When TP = 2, DP = 2 and EP = True, the configuration on different - devices, - - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} - - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} - - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} - - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} - - Comment: There are 2 engine instances and the experts are split - between the 4 devices. 
- """ - - def flatten_tp_across_dp(dp_rank: int): - tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() - # There are actually dp_size_ * tp_size_ devices. Update tp_size - # and tp_rank so we shard across all devices. - tp_size = dp_size_ * tp_size_ - tp_rank = dp_rank * tp_size_ + tp_rank - return tp_size, tp_rank - - use_ep = (dp_size_ * tp_size_ > 1 - and vllm_parallel_config.enable_expert_parallel) - - dp_size = dp_size_ - dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 - tp_size, tp_rank = flatten_tp_across_dp(dp_rank) - - if not use_ep: - return FusedMoEParallelConfig(tp_size=tp_size, - tp_rank=tp_rank, - dp_size=dp_size, - dp_rank=dp_rank, - ep_size=1, - ep_rank=0, - use_ep=False) - # DP + EP / TP + EP / DP + TP + EP - assert use_ep - # In EP, each device owns a set of experts fully. There is no tensor - # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. - ep_size = tp_size - ep_rank = tp_rank - return FusedMoEParallelConfig(tp_size=1, - tp_rank=0, - dp_size=dp_size, - dp_rank=dp_rank, - ep_size=ep_size, - ep_rank=ep_rank, - use_ep=True) - - -# Adapted from pplx-kernels tests/all_to_all_utils.py -@dataclass -class MoEConfig: - num_experts: int - experts_per_token: int - hidden_dim: int - - num_local_experts: int - moe_parallel_config: FusedMoEParallelConfig - - in_dtype: torch.dtype # The activation type. - quant_dtype: torch.dtype = None - - # TODO: add more quantization params, blocked, per-token, etc. 
- block_size: int = 128 - - max_num_tokens: int = MOE_DP_CHUNK_SIZE - - has_bias: bool = False - - @property - def tp_size(self): - return self.moe_parallel_config.tp_size - - @property - def dp_size(self): - return self.moe_parallel_config.dp_size - - @property - def ep_size(self): - return self.moe_parallel_config.ep_size - - @property - def tp_rank(self): - return self.moe_parallel_config.tp_rank - - @property - def dp_rank(self): - return self.moe_parallel_config.dp_rank - - @property - def ep_rank(self): - return self.moe_parallel_config.ep_rank - - @property - def use_ep(self): - return self.moe_parallel_config.use_ep - - @property - def use_pplx_kernels(self): - return self.moe_parallel_config.use_pplx_kernels - - @property - def use_deepep_ht_kernels(self): - return self.moe_parallel_config.use_deepep_ht_kernels - - @property - def use_deepep_ll_kernels(self): - return self.moe_parallel_config.use_deepep_ll_kernels - class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" @@ -276,21 +78,9 @@ class FusedMoeWeightScaleSupported(Enum): BLOCK = "block" -def get_quant_config_input_activations( - quant_config: Optional[QuantizationConfig] -) -> Optional[QuantizationArgs]: - if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') - and "Linear" in quant_config.target_scheme_map and - "input_activations" in quant_config.target_scheme_map["Linear"]): - return quant_config.target_scheme_map["Linear"].get( - "input_activations") - else: - return None - - class FusedMoEMethodBase(QuantizeMethodBase): - moe: MoEConfig + moe: FusedMoEConfig @abstractmethod def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -298,23 +88,37 @@ class FusedMoEMethodBase(QuantizeMethodBase): params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError - def init_prepare_finalize(self, moe: MoEConfig, - quant_config: Optional[QuantizationConfig]): + def uses_weight_scale_2_pattern(self) -> bool: + """ + Returns True if this quantization 
method uses 'weight_scale_2' pattern + for per-tensor weight scales (e.g., FP4 variants), False otherwise. + + This method should be overridden by subclasses that use the + 'weight_scale_2' pattern instead of the standard 'weight_scale' pattern. + """ + return False + + @staticmethod + def maybe_make_prepare_finalize( + moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None - self.moe = moe - quant_dtype = None - act_quant_block_size = None - from vllm.model_executor.layers.quantization.fp8 import Fp8Config - if isinstance(quant_config, Fp8Config): - act_quant_block_size = quant_config.weight_block_size - quant_dtype = torch.float8_e4m3fn + prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None - prepare_finalize: Optional[Union[PplxPrepareAndFinalize, - DeepEPHTPrepareAndFinalize, - DeepEPLLPrepareAndFinalize]] = None + if moe.use_flashinfer_cutlass_kernels: + prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize( + quant_dtype=moe.quant_dtype, ) if moe.use_pplx_kernels: + hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( + moe.max_num_tokens, + moe.hidden_dim, + moe.in_dtype, + moe.quant_dtype, + per_act_token_quant=moe.per_act_token_quant, + block_shape=moe.block_shape, + ) + all_to_all_args = dict( max_num_tokens=moe.max_num_tokens, num_experts=moe.num_experts, @@ -324,16 +128,13 @@ class FusedMoEMethodBase(QuantizeMethodBase): # dp_size actually means tp_size, bug in pplx kernels dp_size=all2all_manager.tp_group.world_size, hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=( - 0 if moe.quant_dtype.itemsize != 1 else - ((moe.hidden_dim + moe.block_size - 1) // moe.block_size * - torch.float32.itemsize)), + 
hidden_dim_bytes=hidden_dim_bytes, + hidden_dim_scale_bytes=hidden_scale_bytes, ) + num_dispatchers = (all2all_manager.world_size // + all2all_manager.tp_group.world_size) + # Intranode pplx a2a takes a group name while internode does not. if not all2all_manager.internode: all_to_all_args[ @@ -341,20 +142,11 @@ class FusedMoEMethodBase(QuantizeMethodBase): handle = all2all_manager.get_handle(all_to_all_args) - input_activations = get_quant_config_input_activations( - quant_config) - prepare_finalize = PplxPrepareAndFinalize( handle, max_num_tokens=moe.max_num_tokens, - world_size=all2all_manager.world_size, - rank=all2all_manager.rank, - # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_manager.tp_group.world_size, - quant_dtype=moe.quant_dtype, - per_act_token=(input_activations.strategy - == QuantizationStrategy.TOKEN - if input_activations is not None else False), + num_local_experts=moe.num_local_experts, + num_dispatchers=num_dispatchers, ) elif moe.use_deepep_ht_kernels: assert moe.dp_size == all2all_manager.dp_world_size @@ -363,18 +155,13 @@ class FusedMoEMethodBase(QuantizeMethodBase): handle = all2all_manager.get_handle(all_to_all_args) prepare_finalize = DeepEPHTPrepareAndFinalize( handle, - world_size=all2all_manager.world_size, - rank=all2all_manager.rank, + num_dispatchers=all2all_manager.world_size, dp_size=all2all_manager.dp_world_size, rank_expert_offset=all2all_manager.rank * moe.num_local_experts, - quant_dtype=quant_dtype, - block_shape=act_quant_block_size, ) elif moe.use_deepep_ll_kernels: - assert moe.dp_size == all2all_manager.dp_world_size - all_to_all_args = dict( max_num_tokens_per_dp_rank=moe.max_num_tokens, token_hidden_size=moe.hidden_dim, @@ -384,35 +171,54 @@ class FusedMoEMethodBase(QuantizeMethodBase): all2all_manager.world_size) handle = all2all_manager.get_handle(all_to_all_args) - # Note (varun): Whether to use FP8 dispatch or not needs some - # profiling. Turning it off for now. 
+ # Note : We may want to use FP8 dispatch even otherwise just to + # reduce datamovement + use_fp8_dispatch = (moe.quant_config is not None + and moe.quant_config.quant_dtype + == current_platform.fp8_dtype() + and moe.quant_config.block_shape + == DEEPEP_QUANT_BLOCK_SHAPE) + prepare_finalize = DeepEPLLPrepareAndFinalize( handle, - world_size=all2all_manager.world_size, - dp_size=all2all_manager.dp_world_size, max_tokens_per_rank=moe.max_num_tokens, - quant_dtype=quant_dtype, - block_shape=act_quant_block_size, - use_fp8_dispatch=False, + num_dispatchers=all2all_manager.world_size, + use_fp8_dispatch=use_fp8_dispatch, ) + return prepare_finalize + + def init_prepare_finalize(self, moe: FusedMoEConfig): + self.moe = moe + prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize( + self.moe) + self.topk_indices_dtype = None if prepare_finalize is not None: + logger.debug("%s", prepare_finalize.__class__.__name__) self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() - experts = self.select_gemm_impl(prepare_finalize, moe) + experts = self.select_gemm_impl(prepare_finalize, self.moe) self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, ) def select_gemm_impl( - self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute: + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate # gemm implementation raise NotImplementedError( - "Subclass must select appropriate gemm implementation" - " based on the prepare_finalize") + f"{self.__class__.__name__} must select appropriate gemm " + "implementation based on the prepare_finalize") + + def maybe_swap_experts_impl( + self, + moe_parallel_config: FusedMoEParallelConfig, + ): + pass @abstractmethod def apply( @@ -432,6 +238,10 @@ class FusedMoEMethodBase(QuantizeMethodBase): e_score_correction_bias: 
Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError @@ -440,13 +250,12 @@ class FusedMoEMethodBase(QuantizeMethodBase): class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" - def __init__(self, moe: MoEConfig): + def __init__(self, moe: FusedMoEConfig): super().__init__() self.fused_experts = fused_experts # type: ignore self.topk_indices_dtype = None self.moe = moe self.has_bias = self.moe.has_bias - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts @@ -454,44 +263,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): else: self.rocm_aiter_fused_experts = None # type: ignore - def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: Optional[MoEConfig]): - - assert self.fused_experts == fused_experts - - all2all_manager = get_ep_group().device_communicator.all2all_manager - assert all2all_manager is not None - - experts: Optional[FusedMoEPermuteExpertsUnpermute] = None - - use_batched_experts = prepare_finalize.max_num_tokens_per_rank( - ) is not None - if use_batched_experts: + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: + if (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe) - assert self.moe.dp_size == all2all_manager.dp_world_size - experts = BatchedTritonExperts( + return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, - world_size=all2all_manager.world_size, - # dp_size actually means tp_size, bug in 
pplx kernels - dp_size=all2all_manager.tp_group.world_size, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - per_channel_quant=False, + num_dispatchers=prepare_finalize.num_dispatchers(), ) else: logger.debug("TritonExperts %s", self.moe) - experts = TritonExperts( - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - per_channel_quant=False, - ) - return experts + return TritonExperts() def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -513,7 +299,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): requires_grad=False) layer.register_parameter("w13_bias", w13_bias) set_weight_attrs(w13_bias, extra_weight_attrs) - # down_proj (row parallel) w2_weight = torch.nn.Parameter(torch.empty( num_experts, @@ -559,14 +344,32 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.w13_weight.data = shuffled_w13 layer.w2_weight.data = shuffled_w2 - if current_platform.is_cpu(): + if current_platform.is_xpu(): + import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + use_prepack=True, + ) + elif current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - import intel_extension_for_pytorch as ipex - layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( - layer.w13_weight, - layer.w2_weight, - use_prepack=envs.VLLM_CPU_MOE_PREPACK, - ) + from vllm.model_executor.layers.fused_moe import cpu_fused_moe + dtype = layer.w13_weight.dtype + if (envs.VLLM_CPU_SGL_KERNEL + and torch._C._cpu._is_amx_tile_supported() + and dtype == torch.bfloat16): + packed_w13_weight = torch.ops._C.convert_weight_packed( + layer.w13_weight) + assert packed_w13_weight.size() == layer.w13_weight.size() + layer.w13_weight.copy_(packed_w13_weight) + del packed_w13_weight + 
packed_w2_weight = torch.ops._C.convert_weight_packed( + layer.w2_weight) + assert packed_w2_weight.size() == layer.w2_weight.size() + layer.w2_weight.copy_(packed_w2_weight) + layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) + else: + layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) else: raise NotImplementedError("CPU MOE only supports x86 arch.") @@ -587,7 +390,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) + return self.forward( x=x, layer=layer, @@ -603,7 +416,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + enable_eplb=enable_eplb, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) def forward_cuda( self, @@ -622,6 +440,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( @@ -635,25 +457,29 @@ 
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count) if self.rocm_aiter_moe_enabled: - assert expert_map is None return self.rocm_aiter_fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, + expert_map=expert_map, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) else: - return self.fused_experts( + # add w1_bias/w2_bias to kwargs if they exist + kwargs = dict( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_bias=layer.w13_bias if self.has_bias else None, - w2_bias=layer.w2_bias if self.has_bias else None, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, @@ -662,6 +488,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): global_num_experts=global_num_experts, expert_map=expert_map, ) + if isinstance(self.fused_experts, + FusedMoEModularKernel) and self.has_bias: + raise ValueError( + "FusedMoEModularKernel does not support bias.") + if self.has_bias: + kwargs.update({ + "w1_bias": getattr(layer, "w13_bias", None), + "w2_bias": getattr(layer, "w2_bias", None), + }) + + return self.fused_experts(**kwargs) def forward_cpu( self, @@ -678,13 +515,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", apply_router_weight_on_input: bool = False, - **kwargs, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, 
+ logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ): - assert activation == "silu", f"{activation} is not supported." - assert apply_router_weight_on_input is False - return layer.ipex_fusion( + if enable_eplb is not False or expert_load_view is not None or \ + logical_to_physical_map is not None or \ + logical_replica_count is not None: + raise NotImplementedError("Expert load balancing is not supported " + "for CPU.") + return layer.cpu_fused_moe( + layer, x, use_grouped_topk, top_k, @@ -692,12 +536,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): renormalize, topk_group, num_expert_group, + global_num_experts, + expert_map, custom_routing_function, scoring_func, e_score_correction_bias, + apply_router_weight_on_input, + activation, ) - def forward_hpu( + def forward_xpu( self, layer: torch.nn.Module, x: torch.Tensor, @@ -714,21 +562,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", - ) -> torch.Tensor: - assert not use_grouped_topk - assert num_expert_group is None - assert topk_group is None + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ): + if enable_eplb is not False or expert_load_view is not None or \ + logical_to_physical_map is not None or \ + logical_replica_count is not None: + raise NotImplementedError("Expert load balancing is not supported " + "for XPU.") assert custom_routing_function is None - assert layer is not None - assert apply_router_weight_on_input is False - if scoring_func != "softmax": - raise NotImplementedError( - "Only softmax scoring function is supported for HPU.") - if e_score_correction_bias is not None: - raise NotImplementedError( - "Expert score correction 
bias is not supported for HPU.") - return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight, - router_logits, top_k) + return layer.ipex_fusion( + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + ) def forward_tpu( self, @@ -747,6 +600,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert not use_grouped_topk assert num_expert_group is None @@ -760,6 +617,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): raise NotImplementedError( "Expert score correction bias is not supported for TPU.") assert activation == "silu", f"{activation} is not supported for TPU." + if enable_eplb is not False or expert_load_view is not None or \ + logical_to_physical_map is not None or \ + logical_replica_count is not None: + raise NotImplementedError("Expert load balancing is not supported " + "for TPU.") return fused_moe_pallas(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -769,7 +631,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=expert_map, renormalize=renormalize) - forward_native = forward_tpu if current_platform.is_tpu() else forward_cuda + if current_platform.is_tpu(): + forward_native = forward_tpu + elif current_platform.is_cpu(): + forward_native = forward_cpu + else: + forward_native = forward_cuda def determine_expert_map( @@ -798,26 +665,25 @@ def determine_expert_map( if ep_size == 1: return (global_num_experts, None) - local_num_experts = global_num_experts // ep_size + # Distribute experts as evenly as possible to each rank. 
+ base_experts = global_num_experts // ep_size + remainder = global_num_experts % ep_size + if ep_rank < remainder: + local_num_experts = base_experts + 1 + else: + local_num_experts = base_experts # Create a tensor of size num_experts filled with -1 expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32) # Create a expert map for the local experts - if ep_rank < (ep_size - 1): - # Each non-last rank gets local_num_experts experts. - expert_map[ep_rank * local_num_experts: - (ep_rank + 1) * local_num_experts] = \ - torch.arange(0, local_num_experts, dtype=torch.int32) - else: - # All remaining experts are assigned to the last rank. - local_num_experts = (global_num_experts - ep_rank * local_num_experts) - - expert_map[-local_num_experts:] = \ - torch.arange(0, local_num_experts, dtype=torch.int32) + start_idx = ep_rank * base_experts + min(ep_rank, remainder) + expert_map[start_idx:start_idx + local_num_experts] = torch.arange( + 0, local_num_experts, dtype=torch.int32) return (local_num_experts, expert_map) -class FusedMoE(torch.nn.Module): +@CustomOp.register("fused_moe") +class FusedMoE(CustomOp): """FusedMoE layer for MoE models. This layer contains both MergedColumnParallel weights (gate_up_proj / @@ -836,6 +702,7 @@ class FusedMoE(torch.nn.Module): reduce_results: Whether to all all_reduce on the output of the layer renomalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. + enable_eplb: Whether to enable expert parallelism load balancer. 
""" def __init__( @@ -860,6 +727,8 @@ class FusedMoE(torch.nn.Module): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + num_redundant_experts: int = 0, has_bias: bool = False, ): super().__init__() @@ -867,28 +736,48 @@ class FusedMoE(torch.nn.Module): params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype + tp_size_ = (tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()) + dp_size_ = (dp_size + if dp_size is not None else get_dp_group().world_size) + vllm_config = get_current_vllm_config() self.moe_parallel_config: FusedMoEParallelConfig = ( FusedMoEParallelConfig.make( - tp_size_=(tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()), - dp_size_=(dp_size if dp_size is not None else - get_dp_group().world_size), + tp_size_=tp_size_, + dp_size_=dp_size_, vllm_parallel_config=vllm_config.parallel_config)) - self.global_num_experts = num_experts + self.global_num_experts = num_experts + num_redundant_experts + + # we padding globally so EP buffer allocation works + if (quant_config and quant_config.get_name() == "mxfp4" + and (current_platform.is_rocm() + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)): + hidden_size = round_up(hidden_size, 256) # For smuggling this layer into the fused moe custom op - self.use_direct_call = self.dp_size == 1 - if not self.use_direct_call: - compilation_config = vllm_config.compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError("Duplicate layer name: {}".format(prefix)) - compilation_config.static_forward_context[prefix] = self - self.layer_name = prefix + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError("Duplicate layer name: {}".format(prefix)) + compilation_config.static_forward_context[prefix] = 
self + self.layer_name = prefix + + self.enable_eplb = enable_eplb + self.expert_load_view: Optional[torch.Tensor] = None + self.logical_to_physical_map: Optional[torch.Tensor] = None + self.logical_replica_count: Optional[torch.Tensor] = None # Determine expert maps if self.use_ep: + if self.enable_eplb: + assert self.global_num_experts % self.ep_size == 0, \ + "EPLB currently only supports even distribution of " \ + "experts across ranks." + else: + assert num_redundant_experts == 0, \ + "Redundant experts are only supported with EPLB." self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, @@ -918,31 +807,23 @@ class FusedMoE(torch.nn.Module): if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") - if current_platform.is_hpu(): - from vllm_hpu_extension.ops import DynamicFusedMOE - self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) - # Only support float8 for now. - quant_dtype = params_dtype - if quant_config is not None: - input_activations = get_quant_config_input_activations( - quant_config) - if (input_activations is not None - and input_activations.num_bits == 8 - and input_activations.type == QuantizationType.FLOAT): - quant_dtype = torch.float8_e4m3fn + if vllm_config.model_config is not None: + model_dtype = vllm_config.model_config.dtype + else: + # TODO (bnell): This is a hack to get test_mixtral_moe to work + # since model_config is not set in the pytest test. 
+ model_dtype = params_dtype - moe = MoEConfig( - num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - in_dtype=params_dtype, - quant_dtype=quant_dtype, - max_num_tokens=MOE_DP_CHUNK_SIZE, - has_bias=has_bias, - ) + moe = FusedMoEConfig.make(num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=model_dtype, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, + quant_config=quant_config, + has_bias=has_bias) self.moe_config = moe self.quant_config = quant_config @@ -956,6 +837,21 @@ class FusedMoE(torch.nn.Module): assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method + if self.enable_eplb: + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8MoEMethod) + if not isinstance(quant_method, + (Fp8MoEMethod, UnquantizedFusedMoEMethod)): + # TODO: Add support for additional quantization methods. + # The implementation for other quantization methods does not + # contain essential differences, but the current quant API + # design causes duplicated work when extending to new + # quantization methods, so I'm leaving it for now. + # If you plan to add support for more quantization methods, + # please refer to the implementation in `Fp8MoEMethod`. 
+ raise NotImplementedError("EPLB is only supported for FP8 " + "quantization for now.") + moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, @@ -972,21 +868,24 @@ class FusedMoE(torch.nn.Module): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) + if isinstance(self.quant_method, FusedMoEMethodBase): + self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config) # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels): - act_dtype = vllm_config.model_config.dtype + or self.moe_parallel_config.use_deepep_ll_kernels + or self.moe_parallel_config.use_flashinfer_cutlass_kernels): self.batched_hidden_states = torch.zeros( - (MOE_DP_CHUNK_SIZE, self.hidden_size), - dtype=act_dtype, + (moe.max_num_tokens, self.hidden_size), + dtype=moe.in_dtype, device=torch.cuda.current_device()) + # Note here we use `num_experts` which is logical expert count self.batched_router_logits = torch.zeros( - (MOE_DP_CHUNK_SIZE, self.global_num_experts), - dtype=act_dtype, + (moe.max_num_tokens, num_experts), + dtype=moe.in_dtype, device=torch.cuda.current_device()) @property @@ -1029,6 +928,19 @@ class FusedMoE(torch.nn.Module): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_flashinfer_cutlass_kernels(self): + return self.moe_parallel_config.use_flashinfer_cutlass_kernels + + def update_expert_map(self): + # ep_size and ep_rank should already be updated + assert self.expert_map is not None + with self.expert_map.device: + self.local_num_experts, self.expert_map = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts) + def _load_per_tensor_weight_scale(self, 
shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -1044,6 +956,18 @@ class FusedMoE(torch.nn.Module): elif shard_id == "w2": param_data[expert_id] = loaded_weight + def _load_combined_w13_weight_scale(self, shard_dim: int, + loaded_weight: torch.Tensor, + param: torch.Tensor, tp_rank: int): + """ + Load w13 weight scales assuming that w1 weight scales and w3 weight + scales are stored in the same loaded_weight tensor. + """ + shard_size = param.shape[shard_dim] + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + param.copy_(loaded_weight) + def _load_model_weight_or_group_weight_scale(self, shard_dim: int, expert_data: torch.Tensor, @@ -1089,14 +1013,21 @@ class FusedMoE(torch.nn.Module): expert_data=expert_data, tp_rank=tp_rank) - def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): + def _load_w13(self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 - loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, - shard_size) + if not load_full: + loaded_weight = loaded_weight.narrow(shard_dim, + shard_size * tp_rank, + shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": @@ -1149,13 +1080,45 @@ class FusedMoE(torch.nn.Module): return expert_id return self.expert_map[expert_id].item() + @overload def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int) -> None: + shard_id: str, expert_id: int, + return_success: Literal[False]) -> None: + ... 
+ + @overload + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int, + return_success: Literal[True]) -> bool: + ... + + def weight_loader(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False) -> Optional[bool]: + + if self.quant_config and self.quant_config.get_name() == "mxfp4": + # (FIXME) for gpt-oss all experts are combined + if "bias" in weight_name: + dim1 = loaded_weight.shape[1] + param.data[:, :dim1].copy_(loaded_weight) + else: + dim1 = loaded_weight.shape[1] + dim2 = loaded_weight.shape[2] + param.data[:, :dim1, :dim2].copy_(loaded_weight) + return True if return_success else None expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: - return + # Failed to load this param since it's not local to this rank + return False if return_success else None + # Hereafter, `expert_id` is local physical id + quant_method_name = self.quant_method.__class__.__name__ # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format @@ -1169,9 +1132,6 @@ class FusedMoE(torch.nn.Module): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}.") - WEIGHT_SCALE_SUPPORTED = [ - e.value for e in FusedMoeWeightScaleSupported - ] # Fetch the dim to shard the parameter/loaded weight # based on the shard id. This will be whatever # dimension intermediate_size_per_partition is used. 
@@ -1182,7 +1142,28 @@ class FusedMoE(torch.nn.Module): if is_gguf_weight_type: param.weight_type = loaded_weight.item() param.data.copy_(loaded_weight) - return + return True if return_success else None + + # Case for BitsAndBytes + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + if use_bitsandbytes_4bit: + shard_dim = 0 + + expert_data = param.data[expert_id] + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + # BNB inflight quantization has already sharded the weights + full_load = True + self._load_w13( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + load_full=full_load, + ) + return True if return_success else None # is_transposed: if the dim to shard the weight # should be flipped. Required by GPTQ, compressed-tensors @@ -1205,6 +1186,7 @@ class FusedMoE(torch.nn.Module): param.materialize(final_shape, dtype=loaded_weight.dtype) expert_data = param.data if full_load else param.data[expert_id] + # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: # this is needed for compressed-tensors only @@ -1221,7 +1203,7 @@ class FusedMoE(torch.nn.Module): self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) - return + return True if return_success else None # Case g_idx if "g_idx" in weight_name: @@ -1230,25 +1212,61 @@ class FusedMoE(torch.nn.Module): loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None + # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern if "ModelOpt" in quant_method_name: - if ('weight_scale_2' in weight_name - or 'input_scale' in weight_name): - self._load_per_tensor_weight_scale(shard_id=shard_id, - param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - elif "weight" in weight_name: + # Determine per-tensor weight scale 
patterns based on variant + # Use the dedicated method instead of brittle string matching + uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern( + ) + + # Call _load_per_tensor_weight_scale() to load per-tensor (scalar) + # weights scales. + # Input scales are always per-tensor. + # Weight scales: FP4 uses "weight_scale_2" and FP8 uses + # "weight_scale" for per-tensor scales. + is_per_tensor = ("weight_scale_2" in weight_name + if uses_weight_scale_2 else "weight_scale" + in weight_name) or "input_scale" in weight_name + if is_per_tensor: + self._load_per_tensor_weight_scale( + shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id, + ) + return True if return_success else None + + # If the weight is w13_weight_scale and w13_weight_scales are + # combined into single loaded_weight, call + # _load_combined_w13_weight_scale() to load it. + # This is checked by comparing the hidden_out dims of the + # loaded_weight and the param. + if "w13_weight_scale" in weight_name: + loaded_weight_hidden_out = loaded_weight.shape[-2] + param_hidden_out = param.data.shape[-2] * self.tp_size + if loaded_weight_hidden_out == param_hidden_out: + self._load_combined_w13_weight_scale( + shard_dim=shard_dim, + loaded_weight=loaded_weight, + param=param, + tp_rank=self.tp_rank, + ) + return True if return_success else None + + # For other weights, call _load_model_weight_or_group_weight_scale() + # to load it. 
+ if "weight" in weight_name: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None - # Case weight scales, zero_points and offset + # Case weight scales, zero_points and offset, weight/input global scales if ("scale" in weight_name or "zero" in weight_name or "offset" in weight_name): # load the weight scales and zp based on the quantization scheme @@ -1281,9 +1299,12 @@ class FusedMoE(torch.nn.Module): loaded_weight=loaded_weight, expert_id=expert_id) else: + WEIGHT_SCALE_SUPPORTED = [ + e.value for e in FusedMoeWeightScaleSupported + ] raise ValueError( f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") - return + return True if return_success else None # Case weight_shape if "weight_shape" in weight_name: @@ -1291,7 +1312,7 @@ class FusedMoE(torch.nn.Module): self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) - return + return True if return_success else None # Case model weights if "weight" in weight_name: @@ -1301,23 +1322,87 @@ class FusedMoE(torch.nn.Module): loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None + + return False if return_success else None + + def get_expert_weights(self) -> Iterable[torch.Tensor]: + weights = list(self.named_parameters()) + assert all(weight.is_contiguous() for _, weight in weights) + + # Filter out the non-expert weights. + # `e_score_correction_bias` is a bias for each logical expert, + # with shape (num_logical_experts,), not an expert weight. 
+ NON_EXPERT_WEIGHTS = { + "e_score_correction_bias", + } + + return [ + weight.view(self.local_num_experts, -1) for name, weight in weights + if name not in NON_EXPERT_WEIGHTS + ] + + def set_eplb_state( + self, + moe_layer_idx: int, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + """ + Register the EPLB state in this layer. + + This is used later in forward pass, where we get the expert mapping + and record the load metrics in `expert_load_view`. + """ + self.expert_load_view = expert_load_view[moe_layer_idx] + self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] + self.logical_replica_count = logical_replica_count[moe_layer_idx] @staticmethod - def select_experts(hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - indices_type: Optional[torch.dtype] = None): + def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + indices_type: Optional[torch.dtype] = None, + enable_eplb: bool = False, + expert_map: Optional[torch.Tensor] = None, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Route the input hidden states to the top-k experts based on the + router logits. 
+ + Returns: + (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]): + The weights and *global physical* expert ids of the top-k experts. + + **Compatibility**: When EPLB is not enabled, the returned ids are + equivalent to global logical ids, so should be compatible with + plain MoE implementations without redundant experts. + """ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk - # DeekSeekv2 uses grouped_top_k + # # Check if we should use a routing simulation strategy + # routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY + # if routing_strategy != "": + # return RoutingSimulator.simulate_routing( + # hidden_states=hidden_states, + # router_logits=router_logits, + # strategy_name=routing_strategy, + # top_k=top_k, + # indices_type=indices_type) + + # DeepSeekv2 uses grouped_top_k if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None @@ -1349,6 +1434,63 @@ class FusedMoE(torch.nn.Module): if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + + # 1. Convert the logical expert ids to physical expert ids + # Directly select a random replica for each logical expert + + # TODO: maybe optimize this by using specified kernels, + # or compute pseudo-random indices by modulo + + # In case `indices_type` is not `torch.long` or `torch.int`, + # e.g. `torch.uint32` as required by dispatch/combine kernels + topk_ids_long = topk_ids.long() + replica_indices = ( + torch.rand_like(topk_ids, dtype=torch.float) * + logical_replica_count[topk_ids_long]).long().unsqueeze(-1) + physical_ids = logical_to_physical_map[topk_ids_long].gather( + -1, replica_indices).squeeze(-1) + + topk_ids = physical_ids + + # 2. Record expert load metrics. 
+ + # TODO(bowen): When using `FusedMoEModularKernel`, this + # can be done in a more unified way, since + # `FusedMoEPrepareAndFinalize` will return the expert + # token count, in some cases directly from the kernel. + # However, now there are many code paths not using + # the modular kernel, e.g. calling `fused_experts`, + # so we decide to keep the logic here. + # + # If later refactor moved all the MoE kernel calls + # to the modular kernel, we can move this logic there + # to achieve better efficiency. + + # `expert_load_view`: (num_physical_experts,) + + topk_ids_flatten = topk_ids.flatten() + + # Performance optimization: + # `masked_fill` is significantly faster than `masked_select` + invalid_mask = topk_ids_flatten < 0 + # Replace invalid expert ids with 0 (just a dummy position) + # to avoid out-of-bounds errors in scatter_add_ + index = topk_ids_flatten.masked_fill_(invalid_mask, 0) + # `src` is the valid mask, which is 1 for valid and 0 for invalid + src = ~invalid_mask + + expert_load_view.scatter_add_(dim=0, + index=index.long(), + src=src.to(expert_load_view)) + + topk_ids = topk_ids.to(dtype=indices_type) + + assert topk_ids.dtype == indices_type or indices_type is None + return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: @@ -1380,11 +1522,20 @@ class FusedMoE(torch.nn.Module): def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - if self.use_direct_call: + og_hidden_states = hidden_states.shape[-1] + if self.hidden_size != og_hidden_states: + hidden_states = F.pad(hidden_states, + (0, self.hidden_size - og_hidden_states), + mode='constant', + value=0.0) + # TODO: Once the OOM issue for the TPU backend is resolved, we will + # switch to using the moe_forward custom op. 
+ if current_platform.is_tpu(): return self.forward_impl(hidden_states, router_logits) else: - return torch.ops.vllm.moe_forward(hidden_states, router_logits, - self.layer_name) + return torch.ops.vllm.moe_forward( + hidden_states, router_logits, + self.layer_name)[..., :og_hidden_states] def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): @@ -1407,7 +1558,7 @@ class FusedMoE(torch.nn.Module): assert (self.batched_hidden_states.size(0) # type: ignore >= chunk_size) - assert (self.batched_router_logits.size(0) # type: ignore + assert (self.batched_router_logits.size(0) # type: ignore >= chunk_size) staged_hidden_states = self.batched_hidden_states[: chunk_size, :] # type: ignore @@ -1432,6 +1583,10 @@ class FusedMoE(torch.nn.Module): scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) if not skip_result_store: @@ -1439,35 +1594,43 @@ class FusedMoE(torch.nn.Module): final_hidden_states, non_blocking=True) ctx = get_forward_context() + # flashinfer_cutlass_kernels can handle: optional DP + TP/EP max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens - num_tokens = full_hidden_states.size(0) - for chunk_start_ in range(0, max_tokens_across_dp, - moe_dp_chunk_size_per_rank): + for chunk_idx, chunk_start_ in enumerate( + range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)): chunk_start = chunk_start_ chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp) # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) - - process_chunk(chunk_start, - chunk_end, - skip_result_store=chunk_start_ >= num_tokens) + with 
ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank, + chunk_idx): + process_chunk(chunk_start, + chunk_end, + skip_result_store=chunk_start_ >= num_tokens) return full_final_hidden_states def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None + # Route to the chunked forward path using the FlashInfer Cutlass kernel + # only when data parallelism (DP) is enabled. + use_flashinfer_cutlass_kernels = ( + self.dp_size > 1 + and self.moe_parallel_config.use_flashinfer_cutlass_kernels) if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels): + or self.moe_parallel_config.use_deepep_ll_kernels + or use_flashinfer_cutlass_kernels): return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( self.dp_size > 1 - and not self.moe_parallel_config.use_deepep_ht_kernels) + and not self.moe_parallel_config.use_deepep_ht_kernels + and not self.moe_parallel_config.use_flashinfer_cutlass_kernels) if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) @@ -1489,11 +1652,14 @@ class FusedMoE(torch.nn.Module): e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) if do_naive_dispatch_combine: final_hidden_states = get_ep_group().combine(final_hidden_states) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs. 
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( @@ -1501,23 +1667,37 @@ class FusedMoE(torch.nn.Module): return final_hidden_states - @classmethod - def make_expert_params_mapping( - cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int) -> list[tuple[str, str, int, str]]: + # @classmethod + # def make_expert_params_mapping( + # cls, + # ckpt_gate_proj_name: str, + # ckpt_down_proj_name: str, + # ckpt_up_proj_name: str, + # num_experts: int, + # num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]: - return [ - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_" if weight_name - in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) - for expert_id in range(num_experts) for shard_id, weight_name in [ - ("w1", ckpt_gate_proj_name), - ("w2", ckpt_down_proj_name), - ("w3", ckpt_up_proj_name), - ] - ] + # num_physical_experts = num_experts + num_redundant_experts + + # # In the returned mapping: + # # - `expert_id` is the physical expert id + # # - `weight_name` contains the weight name of the logical expert + # # So that we should map the expert id to logical in `weight_name` + # physical_to_logical_map = \ + # EplbState.build_initial_global_physical_to_logical_map( + # num_experts, num_redundant_experts) + + # return [ + # # (param_name, weight_name, expert_id, shard_id) + # ("experts.w13_" if weight_name + # in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + # f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", + # expert_id, shard_id) for expert_id in range(num_physical_experts) + # for shard_id, weight_name in [ + # ("w1", ckpt_gate_proj_name), + # ("w2", ckpt_down_proj_name), + # ("w3", ckpt_up_proj_name), + # ] + # ] def extra_repr(self) -> str: @@ -1557,7 +1737,12 @@ def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, direct_register_custom_op( 
op_name="moe_forward", op_func=moe_forward, - mutates_args=[], + mutates_args=["hidden_states"], fake_impl=moe_forward_fake, dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), ) + +# Mark the FusedMoE weight_loader as supporting MoE-specific parameters +# to avoid expensive runtime reflection in model loading code +FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index e7aaf62..6262904 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1,10 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Optional +from dataclasses import dataclass +from enum import Enum +from math import prod +from typing import Any, Optional, final import torch +import vllm.envs as envs +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable + _resize_cache, count_expert_num_tokens) +from vllm.utils import cdiv + # # This file defines a set of base classes used to make MoE kernels more modular. # The goal is to be able to utilize different communication mechanisms with @@ -14,7 +23,7 @@ import torch # # [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] # -# Each component will be independent of the others except for +# Each component will be independent of (but may inform) the others except for # [Quantize-Dispatch] and `[Combine] (see below). The components can then be # mixed and matched with so that DP+EP can be supported easily for multiple # MoE kernel implementations. 
@@ -23,13 +32,19 @@ import torch # * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE # inputs (e.g. quantization, distribution) and finalization of Moe outputs. # The prepare method must take care of any needed quantization and the -# finalize method must apply weights and do the final reduction of the output. +# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method, +# may apply weights and/or do the final reduction of the output. # * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused -# MoE operation. One important feature to note is that this class does not -# apply topk weights or reduce the final output. +# MoE operation, i.e matmul + act_mul + optionally quant + matmul. +# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do +# the weight application and/or reduction. The class communicates this +# to [Finalize] via a TopKWeightAndReduce object. # * FusedMoEModularKernel - an interface class that combines a # FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to # provide the standard fused MoE kernel interface. +# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen +# by the FusedMoEPermuteExpertsUnpermute implementation that is passed +# on to [Finalize]. # # [Quantize-Prepare] and [Finalize] functionality are bundled into a single # class `FusedMoEPrepareAndFinalize` since they could use collective @@ -77,6 +92,56 @@ def _moe_problem_size( return E, M, N, K, topk +class FusedMoEActivationFormat(Enum): + """ + The standard activation format (num_tokens, hidden dim). + """ + Standard = "standard", + """ + The batched experts format (num experts, max tokens per expert, hidden dim) + """ + BatchedExperts = "batched_experts", + + +@dataclass +class ExpertTokensMetadata: + """ + Metadata regarding expert-token routing. 
+ """ + expert_num_tokens: torch.Tensor + expert_num_tokens_cpu: Optional[torch.Tensor] + + @staticmethod + def make_from_list(expert_num_tokens_list: list[int], + device: str) -> "ExpertTokensMetadata": + expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list, + device="cpu", + dtype=torch.int32) + return ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens_cpu.to(device, + non_blocking=True), + expert_num_tokens_cpu=expert_num_tokens_cpu) + + +class TopKWeightAndReduce(ABC): + """ + An abstract base class for weight application and reduction implementations. + """ + + @abstractmethod + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + """ + Apply topk_weights to the fused_experts_outputs and/or reduce. + If an output tensor is not passed, it will be created in the + function. + """ + raise NotImplementedError + + +# TODO: pass FusedMoEParallelConfig in as ctor parameter? 
class FusedMoEPrepareAndFinalize(ABC): """ An abstract base class for the [Quantize-Prepare] and [Finalize] steps @@ -85,17 +150,15 @@ class FusedMoEPrepareAndFinalize(ABC): @abstractmethod def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: """ Perform any quantization (and/or) dispatching needed for this kernel. @@ -114,22 +177,20 @@ class FusedMoEPrepareAndFinalize(ABC): Returns a tuple of: - quantized + dispatched a. - quantized + dispatched a1_scales. - - Optional tensor as big as number of local experts that contains the - number of tokens assigned to each local expert. + - Optional ExpertTokensMetadata containing gpu/cpu tensors + as big as the number of local experts with the information about the + number of tokens assigned to each local expert. 
- Optional dispatched expert topk IDs - - Optional dispatched expert topk weight + - Optional dispatched expert topk weight """ raise NotImplementedError @abstractmethod - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - ) -> None: + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: """ Perform any combine plus apply weights and perform a reduction on the fused experts output. @@ -140,6 +201,17 @@ class FusedMoEPrepareAndFinalize(ABC): - topk_ids: The topk_ids. - apply_router_weight_on_input: When False, apply the weights to fused_expert_output. + - weight_and_reduce_impl: An optional TopKWeightAndReduce + implementation. + """ + raise NotImplementedError + + @property + @abstractmethod + def activation_format(self) -> FusedMoEActivationFormat: + """ + A property indicating the output format of the activations for the + 'prepare' method. """ raise NotImplementedError @@ -159,11 +231,15 @@ class FusedMoEPrepareAndFinalize(ABC): Some PrepareFinalize All2All implementations are batched. Meaning, they can processes only as set of tokens at a time. This function returns the batch size i.e the maximum number of tokens - the implementation can process at a time. + the implementation can process at a time. Return None if there are no such restrictions. """ raise NotImplementedError + @abstractmethod + def num_dispatchers(self) -> int: + raise NotImplementedError + class FusedMoEPermuteExpertsUnpermute(ABC): """ @@ -171,6 +247,57 @@ class FusedMoEPermuteExpertsUnpermute(ABC): above. 
""" + def __init__( + self, + quant_config: Optional[FusedMoEQuantConfig], + ): + if quant_config is not None: + self.quant_config = quant_config + else: + self.quant_config = FusedMoEQuantConfig() + + @property + @abstractmethod + def activation_formats( + self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]: + """ + A property which is a tuple of the input and output activation formats + for the 'apply' method. + """ + raise NotImplementedError + + @property + def quant_dtype(self) -> Optional[torch.dtype]: + return self.quant_config.quant_dtype + + @property + def block_shape(self) -> Optional[list[int]]: + return self.quant_config.block_shape + + @property + def per_act_token_quant(self) -> bool: + return self.quant_config.per_act_token_quant + + @property + def per_out_ch_quant(self) -> bool: + return self.quant_config.per_out_ch_quant + + # TODO (bnell): make this return a CHUNK_SIZE or None instead? + @abstractmethod + def supports_chunking(self) -> bool: + """ + A flag indicating whether or not this class supports activation + chunking. + """ + raise NotImplementedError + + @abstractmethod + def supports_expert_map(self) -> bool: + """ + A flag indicating whether or not this class supports expert maps + """ + raise NotImplementedError + @abstractmethod def workspace_shapes( self, @@ -180,20 +307,25 @@ class FusedMoEPermuteExpertsUnpermute(ABC): N: int, K: int, topk: int, - num_experts: int, - ) -> tuple[int, int, torch.dtype]: + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: """ - Compute the number of elements for the temporary outputs of the two - gemms and activation in the fused expert function. Since the - gemms are independent, the workspace for the first gemm can be shared - with the workspace for the last gemm. 
+ Compute the shapes for the temporary and final outputs of the two gemms + and activation in the fused expert function. Since the gemms are + independent, the workspace for the first gemm can be shared with the + workspace for the last gemm. Returns a tuple of: - - Number of workspace13 elements: must be large enough to hold the + - workspace13 shape tuple: must be large enough to hold the result of either expert gemm. - - Number of workspace2 elements: must be large enough to hold the + - workspace2 shape tuple: must be large enough to hold the result of the activation function. + - output shape tuple: must be exact size of the final gemm output. - Workspace type: The dtype to use for the workspace tensors. + - Note: in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens. """ raise NotImplementedError @@ -207,12 +339,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC): else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") + def enable_chunking(self): + return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \ + self.supports_chunking() + + def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce: + raise NotImplementedError + @abstractmethod def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int, @@ -225,17 +366,22 @@ class FusedMoEPermuteExpertsUnpermute(ABC): a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], + ): """ This function computes the intermediate result of a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2. 
Parameters: + - output: (torch.Tensor): The unweighted, unreduced output tensor. - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. + - topk_weights: A map of row to expert weights. Some implementations + choose to do weight application. - topk_ids (torch.Tensor): A map of row to expert id. - activation (str): The activation function to apply after the first MoE layer. @@ -257,15 +403,28 @@ class FusedMoEPermuteExpertsUnpermute(ABC): must be large enough to hold output of either MoE gemm. - workspace2 (torch.Tensor): A scratch tensor used for the activation function. - - expert_num_tokens: An optional tensor containing the number of tokens - assigned to each expert when using batched experts format input. - - Returns: - - torch.Tensor: The unweighted, unreduced output tensor + - expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional + ExpertTokensMetadata object containing gpu/cpu tensors + as big as the number of local experts with the information about the + number of tokens assigned to each local expert. + - apply_router_weight_on_input: True if router weights are already + applied on the input. This is relevant if the implementation + chooses to do weight application. 
""" raise NotImplementedError +def _chunk_scales(scales: Optional[torch.Tensor], start: int, + end: int) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + return scales + else: + return scales[start:end] + return None + + +@final class FusedMoEModularKernel(torch.nn.Module): """ This class combines a FusedMoEPrepareAndFinalize instance and @@ -287,46 +446,56 @@ class FusedMoEModularKernel(torch.nn.Module): super().__init__() self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts + assert prepare_finalize.activation_format == \ + fused_experts.activation_formats[0], ( + f"{prepare_finalize.__class__.__name__}." + f"{prepare_finalize.activation_format} == " + f"{fused_experts.__class__.__name__}." + f"{fused_experts.activation_formats[0]}") def _do_fused_experts( - self, - a1: torch.Tensor, # input to forward fn - a1q: torch.Tensor, # output of prepare fn - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - expert_num_tokens: torch.Tensor, - activation: str, - global_num_experts: int, + self, fused_out: Optional[torch.Tensor], a1: torch.Tensor, + a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str, global_num_experts: int, local_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor]) -> torch.Tensor: + a2_scale: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) - # Use a1 here to decipher the correct 
workspace datatype - workspace13_shape, workspace2_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k, - global_num_experts)) + (workspace13_shape, workspace2_shape, fused_out_shape, + workspace_dtype) = self.fused_experts.workspace_shapes( + a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, + expert_tokens_meta) - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - workspace13 = torch.zeros(workspace13_shape, + # We can reuse the memory between cache1 and cache3 because by the + # time we need cache3, we're done with cache1. + workspace13 = torch.empty(prod(workspace13_shape), device=a1.device, dtype=workspace_dtype) - workspace2 = torch.zeros(workspace2_shape, + workspace2 = torch.empty(prod(workspace2_shape), device=a1.device, dtype=workspace_dtype) - fused_out = self.fused_experts.apply( + assert fused_out is None or fused_out.shape == fused_out_shape, ( + f"fused_out {fused_out.shape} but expected {fused_out_shape}") + if fused_out is None: + # reuse workspace13 for the output + fused_out = _resize_cache(workspace13, fused_out_shape) + + self.fused_experts.apply( + fused_out, a1q, w1, w2, - topk_ids, + topk_weights=topk_weights, + topk_ids=topk_ids, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, @@ -338,8 +507,162 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) + expert_tokens_meta=expert_tokens_meta, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args) + + return fused_out + + def _maybe_chunk_fused_experts( + self, + a1: torch.Tensor, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + local_num_experts: int, + expert_map: 
Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], + ) -> torch.Tensor: + + _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + num_chunks = cdiv(M, CHUNK_SIZE) + + if not self.fused_experts.supports_chunking() or num_chunks == 1: + return self._do_fused_experts( + fused_out=None, + a1=a1, + a1q=a1q, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + local_num_experts=local_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + expert_tokens_meta=expert_tokens_meta, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args) + + # Chunking required case + assert num_chunks > 1 + + # Construct the entire output that can then be processed in chunks. 
+ (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( + a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, + expert_tokens_meta) + fused_out = torch.empty(fused_out_shape, + device=a1q.device, + dtype=a1.dtype) + + def slice_input_tensors( + chunk_idx: int + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor], torch.Tensor, torch.Tensor]: + s = chunk_idx * CHUNK_SIZE + e = min(s + CHUNK_SIZE, M) + return (a1q[s:e], _chunk_scales(a1q_scale, s, e), + _chunk_scales(a2_scale, s, + e), topk_ids[s:e], topk_weights[s:e]) + + def slice_output_tensor(chunk_idx: int) -> torch.Tensor: + assert fused_out.size(0) % M == 0, ( + f"fused_out shape {fused_out.shape} vs M {M}") + factor = fused_out.size(0) // M + out_chunk_size = CHUNK_SIZE * factor + s = chunk_idx * out_chunk_size + e = min(s + out_chunk_size, fused_out.size(0)) + return fused_out[s:e] + + def slice_expert_tokens_metadata( + full_expert_tokens_meta: ExpertTokensMetadata, + chunk_topk_ids: torch.Tensor, local_num_experts: int, + expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata: + # The existing expert_num_tokens is for the entire a1q + # input. Chunking forces recomputation of the number + # of tokens assigned to each expert. 
+ c_expert_num_tokens = count_expert_num_tokens( + chunk_topk_ids, local_num_experts, expert_map) + + c_expert_num_tokens_cpu = None + need_expert_num_tokens_cpu = ( + full_expert_tokens_meta.expert_num_tokens_cpu is not None) + if need_expert_num_tokens_cpu: + # This is blocking as some implementations need the count + # on the CPU to determine appropriate input/out fused-moe + # buffers + c_expert_num_tokens_cpu = c_expert_num_tokens.to( + "cpu", non_blocking=False) + + return ExpertTokensMetadata( + expert_num_tokens=c_expert_num_tokens, + expert_num_tokens_cpu=c_expert_num_tokens_cpu) + + m = None + if extra_expert_args is not None and 'm' in extra_expert_args: + m = extra_expert_args.get('m') + + if extra_expert_args is not None: + chunked_extra_expert_args = extra_expert_args + else: + chunked_extra_expert_args = {} + + for chunk_idx in range(num_chunks): + c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( + slice_input_tensors(chunk_idx)) + + c_expert_tokens_meta = None + if expert_tokens_meta is not None: + c_expert_tokens_meta = slice_expert_tokens_metadata( + expert_tokens_meta, c_topk_ids, local_num_experts, + expert_map) + + s = chunk_idx * CHUNK_SIZE + e = min(s + CHUNK_SIZE, M) + + if m is not None: + chunked_extra_expert_args['m'] = e - s + self._do_fused_experts( + fused_out=slice_output_tensor(chunk_idx), + a1=a1, + a1q=c_a1q, + w1=w1, + w2=w2, + topk_weights=c_topk_weights, + topk_ids=c_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + local_num_experts=local_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=c_a1q_scale, + a2_scale=c_a2_scale, + expert_tokens_meta=c_expert_tokens_meta, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=chunked_extra_expert_args) return fused_out @@ -361,6 +684,9 @@ class FusedMoEModularKernel(torch.nn.Module): a1_scale: Optional[torch.Tensor] = None, a2_scale: 
Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, + extra_expert_args: Optional[dict] = None, + extra_prepare_args: Optional[dict] = None, + extra_finalize_args: Optional[dict] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -393,6 +719,12 @@ class FusedMoEModularKernel(torch.nn.Module): - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. + - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to + fused_experts.apply. + - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass + to prepare. + - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass + to finalize. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -401,19 +733,31 @@ class FusedMoEModularKernel(torch.nn.Module): a1 = hidden_states output = a1 if inplace else torch.zeros_like(a1) + local_num_experts = w1.size(0) if global_num_experts == -1: - global_num_experts = w1.size(0) + global_num_experts = local_num_experts - (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( - a1, a1_scale, a2_scale, topk_weights, topk_ids, - global_num_experts, expert_map, apply_router_weight_on_input) + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + extra_prepare_args, + ) + # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids topk_weights = (topk_weights if _expert_topk_weights is None else _expert_topk_weights) fused_out = None + if a1q.numel() == 0: # This happens when none of the tokens from the all2all reach this # EP rank. 
Also, note that this is only relevant for CUDAGraph @@ -423,24 +767,31 @@ class FusedMoEModularKernel(torch.nn.Module): # and can never run into the tensor.numel() == 0 case. fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) else: - fused_out = self._do_fused_experts( + fused_out = self._maybe_chunk_fused_experts( a1=a1, a1q=a1q, w1=w1, w2=w2, + topk_weights=topk_weights, topk_ids=topk_ids, - expert_num_tokens=expert_num_tokens, activation=activation, global_num_experts=global_num_experts, + local_num_experts=local_num_experts, expert_map=expert_map, w1_scale=w1_scale, w2_scale=w2_scale, w1_zp=w1_zp, w2_zp=w2_zp, a1q_scale=a1q_scale, - a2_scale=a2_scale) + a2_scale=a2_scale, + expert_tokens_meta=expert_tokens_meta, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args) - self.prepare_finalize.finalize(output, fused_out, topk_weights, - topk_ids, apply_router_weight_on_input) + self.prepare_finalize.finalize( + output, fused_out, topk_weights, topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl(), + extra_finalize_args) return output diff --git a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py new file mode 100644 index 0000000..fb398ee --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import vllm._custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk + + +class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): + """ + Useful in the case when some FusedMoEPermuteExpertsUnpermute + implementation does not perform weight application and reduction + but cannot address the needs of all the compatible PrepareAndFinalize + implementations. 
+ For example, BatchedTritonExperts is compatible with both + PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize + does the weight-application + reduction as part of the pplx combine kernel. + But the BatchedPrepareAndFinalize needs an implementation. To facilitate + this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate + so the PrepareAndFinalize implementations could choose how to + weight + reduce. + """ + + def __eq__(self, other): + return isinstance(other, TopKWeightAndReduceDelegate) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + raise RuntimeError("The caller is expected to choose an appropriate " + "TopKWeightAndReduce implementation.") + + +class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): + """ + The fused_experts outputs have already been weight applied and reduced. + This implementation is a no-op. + """ + + def __eq__(self, other): + return isinstance(other, TopKWeightAndReduceNoOP) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + # Weight application and reduction operations are already done. + if output is None: + return fused_expert_output + + # MoEPrepareAndFinalizeNoEP needs the output to be in the `output` + # tensor. + assert output.size() == fused_expert_output.size(), ( + "output shape is expected to match the fused_expert_output shape. 
" + f"But got output={output.size()}, " + f"fused_expert_output={fused_expert_output.size()}") + output.copy_(fused_expert_output, non_blocking=True) + return output + + +class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce): + """ + TopKWeightAndReduce implementation for a fused_experts output + of shape (m, topk, K) + """ + + def __eq__(self, other): + return isinstance(other, TopKWeightAndReduceContiguous) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + + m, num_topk = topk_ids.size() + k = fused_expert_output.size(-1) + if fused_expert_output.ndim == 2: + fused_expert_output = fused_expert_output.view(m, num_topk, k) + + assert fused_expert_output.size() == (m, num_topk, k), ( + f"Expected fused_expert_output size {(m, num_topk, k)}. But got " + f"{fused_expert_output.size()}") + + if not apply_router_weight_on_input: + fused_expert_output.mul_(topk_weights.view(m, -1, 1)) + + if output is None: + output = torch.empty((m, k), + device=fused_expert_output.device, + dtype=fused_expert_output.dtype) + assert output.size() == (m, k), ( + f"Expected output size {(m, k)}. 
But got {output.size()}") + + ops.moe_sum(fused_expert_output, output) + return output + + +class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce): + """ + TopKWeightAndReduce implementation for a fused_experts output + of shape (num_experts, batch_size, K) + """ + + def __init__(self, rank: int): + self.rank = rank + + def __eq__(self, other): + return (isinstance(other, TopKWeightAndReduceNaiveBatched) + and (other.rank == self.rank)) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + assert fused_expert_output.ndim == 3 + num_tokens = topk_ids.size(0) + num_local_experts = fused_expert_output.size(0) + K = fused_expert_output.size(-1) + + if output is None: + output = torch.zeros((num_tokens, K), + device=fused_expert_output.device, + dtype=fused_expert_output.dtype) + else: + output.fill_(0) + + assert output.size() == (num_tokens, K), ( + f"Expected output size {(num_tokens, K)}, but got {output.size()}") + + first_expert = num_local_experts * self.rank + last_expert = first_expert + num_local_experts + + for expert_id in range(first_expert, last_expert): + matching_tokens = topk_ids == expert_id + topks = torch.any(matching_tokens, dim=1).flatten() + rows = torch.count_nonzero(topks) + rhs = fused_expert_output[expert_id - first_expert, :rows, :] + if not apply_router_weight_on_input: + rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1)) + output[topks] = output[topks] + rhs + + return output diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index cb49594..80e298e 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from math import prod -from typing import 
Optional +from typing import Any, Optional, Union + +import torch + +@@ -10,7 +10,83 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + quant_dequant_mxfp4) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils import cdiv +# from vllm.utils.flashinfer import fp4_quantize + + +@triton.jit +def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts, + topk_numel, expert_map, + HAS_EXPERT_MAP: tl.constexpr, + BLOCK_SIZE: tl.constexpr): + + curr_expert = tl.program_id(0) + + offsets = tl.arange(0, BLOCK_SIZE) + topk_ids_ptrs = topk_ids_ptr + offsets + + acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32) + for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)): + mask = offsets < (topk_numel - x * BLOCK_SIZE) + expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1) + if HAS_EXPERT_MAP: + expert_map_ptrs = expert_map + expert_ids + expert_map_mask = expert_ids >= 0 + expert_ids = tl.load(expert_map_ptrs, + mask=expert_map_mask, + other=-1) + + has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0) + acc = acc + has_curr_expert + topk_ids_ptrs += BLOCK_SIZE + + if curr_expert < num_experts: + tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc)) + + +def count_expert_num_tokens( + topk_ids: torch.Tensor, num_local_experts: int, + expert_map: Optional[torch.Tensor]) -> torch.Tensor: + """ + Count the number of tokens assigned to each expert. + + Parameters: + - topk_ids (torch.Tensor): Tensor mapping each token to its + list of experts. + - num_local_experts (int): Number of experts in this rank. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. 
+ + Returns: + A tensor of size num_local_experts, where tensor[i] holds the number + of tokens assigned to the ith expert. + """ + assert topk_ids.dtype.is_signed, ( + "The kernel uses -1 to represent invalid topk_ids") + expert_num_tokens = torch.empty((num_local_experts), + device=topk_ids.device, + dtype=torch.int32) + + grid = num_local_experts + BLOCK_SIZE = min(topk_ids.numel(), 1024) + BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE) + + _count_expert_num_tokens[(grid, )]( + topk_ids, + expert_num_tokens, + num_local_experts, + topk_ids.numel(), + expert_map, + HAS_EXPERT_MAP=expert_map is not None, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return expert_num_tokens def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: @@ -23,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: return x.flatten()[:prod(v)].view(*v) +# def _fp4_quantize( +# A: torch.Tensor, +# A_scale: Optional[torch.Tensor], +# is_sf_swizzled_layout: bool, +# ) -> tuple[torch.Tensor, torch.Tensor]: +# return fp4_quantize(A, +# A_scale, +# is_sf_swizzled_layout=is_sf_swizzled_layout) + + def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], @@ -34,9 +120,12 @@ def _fp8_quantize( is provided, the output will be blocked. 
""" if block_shape is None: + # TODO(luka): use QuantFP8 custom op + # https://github.com/vllm-project/vllm/issues/20711 A, A_scale = ops.scaled_fp8_quant( A, A_scale, use_per_token_if_dynamic=per_act_token) else: + assert not per_act_token assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] A, A_scale = per_token_group_quant_fp8(A, block_k) @@ -62,9 +151,9 @@ def _int8_quantize( if block_shape is None: assert per_act_token, \ "int8 quantization only supports block or channel-wise" - # A, A_scale = per_token_quant_int8(A) - A, A_scale, _ = ops.scaled_int8_quant(A, A_scale) + A, A_scale = per_token_quant_int8(A) else: + assert not per_act_token assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] A, A_scale = per_token_group_quant_int8(A, block_k) @@ -73,19 +162,40 @@ def _int8_quantize( return A, A_scale +def _mxfp4_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, None]: + assert block_shape is None + if not current_platform.supports_mx(): + A = quant_dequant_mxfp4(A) + else: + raise NotImplementedError() + + return A, None + + def moe_kernel_quantize_input( A: torch.Tensor, A_scale: Optional[torch.Tensor], - qtype: Optional[torch.dtype], - per_channel_quant: bool, + quant_dtype: Union[None, torch.dtype, str], + per_act_token_quant: bool, block_shape: Optional[list[int]] = None, + is_fp4_scale_swizzled: bool = True, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if qtype == torch.float8_e4m3fn: - return _fp8_quantize(A, A_scale, per_channel_quant, block_shape) - elif qtype == torch.int8: - return _int8_quantize(A, A_scale, per_channel_quant, block_shape) + if quant_dtype == torch.float8_e4m3fn: + return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == torch.int8: + return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == torch.uint8: # nvfp4 + 
return _fp4_quantize(A, + A_scale, + is_sf_swizzled_layout=is_fp4_scale_swizzled) + elif quant_dtype == "mxfp4": + return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) else: - assert A_scale is None return A, A_scale @@ -97,3 +207,62 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) else: return m[idx, ...] + + +def normalize_scales_shape( + scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + scales = scales.view(1, 1) + else: + scales = scales.view(-1, scales.size(-1)) + return scales + + +def normalize_batched_scales_shape( + scales: Optional[torch.Tensor], + num_experts: int, +) -> Optional[torch.Tensor]: + if scales is not None and scales.ndim < 3: + if scales.numel() == 1: + scales = scales.view(1) + scales = torch.repeat_interleave(scales, num_experts, + dim=0).view(num_experts, 1, 1) + else: + scales = scales.view(num_experts, -1, scales.size(-1)) + + return scales + + +def _validate_scale_shape( + a: torch.Tensor, + a_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]], +) -> None: + if a_scale is None: + return + + if not per_act_token_quant and block_shape is None: + assert a_scale.numel() == 1, f"{a_scale.shape}" + elif per_act_token_quant: + assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, ( + f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1") + else: + assert block_shape is not None + expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) + assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" + + +def extract_required_args( + extra_args: Optional[dict[str, Any]], + required_keys: list[str], +) -> tuple[Any, ...]: + if extra_args is None: + raise ValueError("`extra_args` must be provided.") + + missing_keys = [k for k in required_keys if k not in extra_args] + if missing_keys: + raise ValueError(f"Missing 
keys in `extra_args`: {missing_keys}") + + return tuple(extra_args[k] for k in required_keys) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index c3c184a..dbd8708 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -36,6 +36,7 @@ QuantizationMethods = Literal[ "moe_wna16", "torchao", "auto-round", + "mxfp4", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config + from .mxfp4 import Mxfp4Config from .neuron_quant import NeuronQuantConfig from .ptpc_fp8 import PTPCFp8Config from .qqq import QQQConfig @@ -143,6 +145,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "moe_wna16": MoeWNA16Config, "torchao": TorchAOConfig, "auto-round": AutoRoundConfig, + "mxfp4": Mxfp4Config, } # Update the `method_to_config` with customized quantization methods. 
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py new file mode 100644 index 0000000..b0932fc --- /dev/null +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -0,0 +1,581 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import envs +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, + FusedMoEMethodBase, fused_experts) +from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( + triton_kernel_moe_forward) +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + _can_support_mxfp4, _swizzle_mxfp4) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer, + next_power_of_2, round_up) + +if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + # from flashinfer.fused_moe import cutlass_fused_moe + from flashinfer import (mxfp8_quantize, shuffle_matrix_a, + shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) + + +class Mxfp4Config(QuantizationConfig): + + def __init__(self, ignored_layers: Optional[list[str]] = None): + super().__init__() + 
self.ignored_layers = ignored_layers + + @classmethod + def from_config(cls, config): + return cls() + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "mxfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if self.ignored_layers and is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + raise NotImplementedError("Mxfp4 linear layer is not implemented") + elif isinstance(layer, FusedMoE): + return Mxfp4MoEMethod(layer.moe_config) + elif isinstance(layer, Attention): + raise NotImplementedError( + "Mxfp4 attention layer is not implemented") + return None + + +class Mxfp4MoEMethod(FusedMoEMethodBase): + + def __init__(self, moe: FusedMoEConfig): + super().__init__() + self.topk_indices_dtype = None + self.moe = moe + self.use_marlin = self._should_use_marlin() + + def _should_use_marlin(self): + if envs.VLLM_MXFP4_USE_MARLIN is not None: + return envs.VLLM_MXFP4_USE_MARLIN + # if current_platform.is_cuda() and \ + # not current_platform.has_device_capability(100): + # if not current_platform.is_device_capability(90): + # # marlin kernel has better performance on ampere + # return True + # if not has_triton_kernels(): + # return True + # if not is_torch_equal_or_newer("2.8.0"): + # return True + return False + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + self.num_experts = num_experts + weight_dtype = torch.uint8 + 
scale_dtype = torch.uint8 + + # FIXME (zyongye): ship after torch and safetensors support mxfp4 + # is_torch_mxfp4_available = ( + # hasattr(torch, "float4_e2m1fn_x2") and + # hasattr(torch, "float8_e8m0fnu")) + # if is_torch_mxfp4_available: + # weight_dtype = torch.float4_e2m1fn_x2 + # scale_dtype = torch.float8_e8m0fnu + + mxfp4_block = 32 + + intermediate_size_per_partition_after_pad = \ + intermediate_size_per_partition + if self.use_marlin: + # The moe marlin kernel requires that for each linear + # n % 256 == 0 and k % 128 == 0. + # In gate_up_proj: + # n = 2 * intermediate_size_per_partition_after_pad + # k = hidden_size + # In down_proj + # n = hidden_size + # k = intermediate_size_per_partition_after_pad + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 128) + hidden_size = round_up(hidden_size, 256) + + layer.params_dtype = params_dtype + layer.num_experts = num_experts + layer.hidden_size = hidden_size + layer.intermediate_size_per_partition = \ + intermediate_size_per_partition_after_pad + elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + # pad the intermediate size to be a multiple of 2 * mxfp4_block + # for to hold non-uniform sharded tensor as well as swizzling + # other padding to increase performance + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 256) + hidden_size = round_up(hidden_size, 256) + elif current_platform.is_rocm(): + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 128) + else: + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 64) + + self.intermediate_size = intermediate_size_per_partition_after_pad + self.hidden_size = hidden_size + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + hidden_size // 2, + dtype=weight_dtype, 
+ ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + hidden_size // mxfp4_block, + dtype=scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad // mxfp4_block, + dtype=scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + def process_weights_after_loading(self, layer): + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + layer.gemm1_alpha = Parameter(torch.tensor( + [1.702] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_beta = Parameter(torch.tensor( + [1.0] * 
self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_clamp_limit = Parameter(torch.tensor( + [7.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + sf_block_size = 32 # mxfp4 block size + + assert (layer.w13_weight.dim() == 3 + and layer.w13_weight.shape[0] == self.num_experts + and layer.w13_weight.shape[1] == self.intermediate_size * 2 + and layer.w13_weight.shape[2] == self.hidden_size // 2) + assert (layer.w13_weight_scale.dim() == 3 + and layer.w13_weight_scale.shape[0] == self.num_experts + and layer.w13_weight_scale.shape[1] + == self.intermediate_size * 2 + and layer.w13_weight_scale.shape[2] + == self.hidden_size // sf_block_size) + assert (layer.w2_weight.dim() == 3 + and layer.w2_weight.shape[0] == self.num_experts + and layer.w2_weight.shape[1] == self.hidden_size and + layer.w2_weight.shape[2] == self.intermediate_size // 2) + assert (layer.w2_weight_scale.dim() == 3 + and layer.w2_weight_scale.shape[1] == self.hidden_size + and layer.w2_weight_scale.shape[2] + == self.intermediate_size // sf_block_size) + assert (layer.w13_bias.dim() == 2 + and layer.w13_bias.shape[0] == self.num_experts + and layer.w13_bias.shape[1] == self.intermediate_size * 2) + assert (layer.w2_bias.dim() == 2 + and layer.w2_bias.shape[0] == self.num_experts + and layer.w2_bias.shape[1] == self.hidden_size) + + w13_weight_scale = layer.w13_weight_scale.data + w2_weight_scale = layer.w2_weight_scale.data + w13_weight = layer.w13_weight.data + w2_weight = layer.w2_weight.data + w13_bias = layer.w13_bias.data.to(torch.float32) + w2_bias = layer.w2_bias.data.to(torch.float32) + + # Swap w1 and w3 as the definition of + # swiglu is different in the trtllm-gen + def swap_every_two_rows(x, axis=-1): + shape = x.shape + if axis < 0: + axis = len(shape) + axis + + # Create a new shape with pairs swapped along specified axis + new_shape = list(shape) + new_shape[axis] = shape[axis] // 2 + new_shape.insert(axis + 1, 2) + + # Reshape to 
expose pairs, swap them, and reshape back + x = x.reshape(*new_shape) + x = x.flip(axis + 1) + new_shape = list(shape) + return x.reshape(*new_shape) + + w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2) + w13_weight = swap_every_two_rows(w13_weight, -2) + w13_bias = swap_every_two_rows(w13_bias, -1) + + # Do not interleave as the checkpoint is already interleaved + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_mxfp4_shuffled = [] + gemm1_scales_mxfp4_shuffled = [] + gemm2_weights_mxfp4_shuffled = [] + gemm2_scales_mxfp4_shuffled = [] + gemm1_bias_shuffled = [] + gemm2_bias_shuffled = [] + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + for i in range(self.num_experts): + gemm1_weights_mxfp4_shuffled.append( + shuffle_matrix_a(w13_weight[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_scales_mxfp4_shuffled.append( + shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_bias_shuffled.append( + shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1), + epilogue_tile_m)) + + gemm2_weights_mxfp4_shuffled.append( + shuffle_matrix_a(w2_weight[i].view(torch.uint8), + epilogue_tile_m)) + gemm2_scales_mxfp4_shuffled.append( + shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm2_bias_shuffled.append( + shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m)) + + w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) + w13_weight_scale = torch.stack( + gemm1_scales_mxfp4_shuffled).reshape( + self.num_experts, 2 * self.intermediate_size, + self.hidden_size // sf_block_size).view( + torch.float8_e4m3fn) + + w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled) + w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape( + self.num_experts, self.hidden_size, self.intermediate_size // + sf_block_size).view(torch.float8_e4m3fn) + + layer.w13_weight = Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = 
Parameter(w13_weight_scale, + requires_grad=False) + layer.w2_weight = Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = Parameter(w2_weight_scale, + requires_grad=False) + layer.w13_bias = Parameter( + torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1), + requires_grad=False) + layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape( + self.num_experts, -1), + requires_grad=False) + elif has_triton_kernels(): + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + w13_bias = layer.w13_bias.to(torch.float32) + w2_bias = layer.w2_bias.to(torch.float32) + + layer.w13_bias = Parameter(w13_bias, requires_grad=False) + layer.w2_bias = Parameter(w2_bias, requires_grad=False) + + # FIXME warp need to be adjusted based on batch size + # only apply to batched mode + if self.moe.use_ep: + num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 + else: + num_warps = 8 + + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( + layer.w13_weight, layer.w13_weight_scale, num_warps) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( + layer.w2_weight, layer.w2_weight_scale, num_warps) + + self.w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)) + self.w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)) + + self.w13_weight_triton_tensor = w13_weight + self.w2_weight_triton_tensor = w2_weight + + # need to delete the original weights to save memory on single GPU + del layer.w13_weight + del layer.w2_weight + layer.w13_weight = None + layer.w2_weight = None + torch.cuda.empty_cache() + else: + # normal triton + from .triton_kernels_numerics_details.mxfp import upcast_from_mxfp + w13_weight = upcast_from_mxfp( + layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1 + ) + w2_weight = upcast_from_mxfp( + layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1 + ) + del layer.w13_weight + del layer.w2_weight 
+ del layer.w13_weight_scale + del layer.w2_weight_scale + layer.w13_weight = Parameter(w13_weight.data, requires_grad=False) + layer.w2_weight = Parameter(w2_weight.data, requires_grad=False) + torch.cuda.empty_cache() + + + def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # - 1.0 means perfect expert distribution. + # - > 1.0 means some experts have more + # tokens than the perfect distribution. + # - < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // self.num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile + # as it's the range supported by the kernel. 
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + + return tile_tokens_dim + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if enable_eplb: + raise NotImplementedError("EPLB is not supported for mxfp4") + + if self.use_marlin: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_bias, + layer.w2_bias, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=None, + global_scale2=None, + quant_type_id=scalar_types.float4_e2m1f.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + activation=activation, + expert_map=expert_map) + + assert _can_support_mxfp4( + use_grouped_topk, topk_group, num_expert_group, expert_map, + custom_routing_function, e_score_correction_bias, + apply_router_weight_on_input, scoring_func, activation, + 
expert_load_view, logical_to_physical_map, + logical_replica_count), ( + "MXFP4 are not supported with this configuration.") + + if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + assert not self.moe.use_ep, ( + "EP is not supported for flashinfer mxfp4 moe backend yet.") + if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: + assert x.dtype == torch.bfloat16 + x_quant = x + x_scale = None + else: + x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 + x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) + trtllm_gen_output = trtllm_fp4_block_scale_moe( + router_logits.to(torch.bfloat16), + None, # routing_bias + x_quant, + x_scale, + layer.w13_weight, # uint8 (e2m1 x 2) + layer.w13_weight_scale, # uint8 (e4m3 x 2) + layer.w13_bias, # fp32 per expert per channel + layer.gemm1_alpha, # fp32 per expert + layer.gemm1_beta, # fp32 per expert + layer.gemm1_clamp_limit, # fp32 per expert + layer.w2_weight, # uint8 (e2m1 x 2) + layer.w2_weight_scale, # ue8m0 + layer.w2_bias, # fp32 per expert per channel + None, # output1_scale_scalar + None, # output1_scale_gate_scalar + None, # output2_scale_scalar + self.num_experts, + top_k, + None, # n_group + None, # topk_group + self.intermediate_size, # padded to multiple of 256 + 0, # local_expert_offset + self.num_experts, # local num experts + None, + self._get_tile_tokens_dim(x, top_k), + 1 if renormalize else 0, # routing_method_type, renormalize + True, # do finalize + )[0] + return trtllm_gen_output + elif has_triton_kernels(): + return triton_kernel_moe_forward( + hidden_states=x, + w1=self.w13_weight_triton_tensor, + w2=self.w2_weight_triton_tensor, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_precision=self.w13_precision_config, + w2_precision=self.w2_precision_config, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + 
else: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py index 3c56251..880438a 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -6,14 +6,16 @@ from typing import Any, Callable, Optional import torch import torch.nn.functional as F -import vllm.envs as envs +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4) + OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) from vllm.platforms import current_platform +logger = init_logger(__name__) + __all__ = ["QuarkW4A4MXFP4"] @@ -25,7 +27,29 @@ class QuarkW4A4MXFP4(QuarkScheme): self.qscheme = "per_group" self.weight_quant_spec = weight_quant_spec self.input_quant_spec = input_quant_spec - self.emulate = not current_platform.supports_mx() + + self.static_input_scales = not 
input_quant_spec.get("is_dynamic") + + if self.static_input_scales: + raise NotImplementedError( + "QuarkW4A4MXFP4 with static input scales is currently not " + "implemented. Please open an issue.") + + if not current_platform.supports_mx(): + self.emulate = True + logger.warning_once( + "The current platform does not support native MXFP4 " + "computation. Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") + else: + self.emulate = True + logger.warning_once( + "The current platform supports native MXFP4 " + "computation, but kernels are not yet integrated in vLLM. " + "Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") @classmethod def get_min_capability(cls) -> int: @@ -37,43 +61,6 @@ class QuarkW4A4MXFP4(QuarkScheme): layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, requires_grad=False) - if self.emulate: - try: - from quark.torch.export.nn.modules import realquantizer - from quark.torch.quantization.config.config import ( - QuantizationSpec) - except ImportError as err: - raise ImportError( - "The package `amd-quark` is required to use AMD Quark " - "MX-FP4 models. 
Please install it with `pip install " - "amd-quark`.") from err - - weight_quant_spec = QuantizationSpec.from_dict( - self.weight_quant_spec) - - weight_quantizer = realquantizer.get_real_quantizer( - qspec=weight_quant_spec, - quantizer=None, - real_quantized=True, - reorder=False, - float_dtype=self.out_dtype, - scale_shape=layer.weight_scale.shape, - zero_point_shape=None, - ) - weight_quantizer.scale.data = layer.weight_scale.data - - if not envs.VLLM_QUARK_EMU_MEM_OPT: - layer.weight = torch.nn.Parameter( - weight_quantizer(layer.weight.data).to(self.out_dtype), - requires_grad=False, - ) - else: - self.weight_quantizer = weight_quantizer - layer.weight_scale = None - - # This call is necessary to release the scales memory. - torch.cuda.empty_cache() - def create_weights(self, layer: torch.nn.Module, output_partition_sizes: list[int], input_size_per_partition: int, @@ -116,11 +103,10 @@ class QuarkW4A4MXFP4(QuarkScheme): bias: Optional[torch.Tensor] = None) -> torch.Tensor: if self.emulate: - if envs.VLLM_QUARK_EMU_MEM_OPT: - dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype) - else: - dq_w = layer.weight - qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE) - return F.linear(qdq_x, dq_w, bias) + dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype) + + x = quant_dequant_mxfp4(x) + + return F.linear(x, dq_w, bias) else: raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/triton_kernels_numerics_details/_downcast_to_mxfp.py b/vllm/model_executor/layers/quantization/triton_kernels_numerics_details/_downcast_to_mxfp.py new file mode 100644 index 0000000..57317db --- /dev/null +++ b/vllm/model_executor/layers/quantization/triton_kernels_numerics_details/_downcast_to_mxfp.py @@ -0,0 +1,158 @@ +import triton +import triton.language as tl + +# fmt: off + + +MXFP_BLOCK_SIZE = tl.constexpr(32) + + +@triton.jit +def _get_max_quant_val(dtype: tl.constexpr): + if dtype == tl.uint8: + return 6.0 + elif dtype == 
tl.float8e5: + return 57344.0 + elif dtype == tl.float8e4nv: + return 448.0 + else: + tl.static_assert(False, f"Invalid {dtype=}") + +@triton.jit +def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr, + DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0): + is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5 + BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0] + BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1] + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE + + # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16 + f32_tensor = src_tensor.to(tl.float32) + abs_tensor = tl.abs(f32_tensor) + abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation + abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE]) + max_val = tl.max(abs_tensor, axis=2, keep_dims=True) + dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype) + if DEQUANT_SCALE_ROUNDING_MODE == 0: + # DequantScaleRoundingMode.ROUND_UP + # compute 2 ** ceil(log2(dequant_scale)) + # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros + # A corner case: exponent is 0xFF that will overflow but that's already + # NaN so assume we don't care. 
+ dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 + else: + # DequantScaleRoundingMode.ROUND_DOWN + # compute 2 ** floor(log2(dequant_scale)) + assert DEQUANT_SCALE_ROUNDING_MODE == 1 + dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded) + + f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE]) + quant_tensor = f32_tensor * quant_scale + + # Reshape the tensors after scaling + quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format. + quant_tensor = tl.where(valid_src_mask, quant_tensor, 0) + dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE]) + + # First, we simply extract the exponent part of the scales and store the result + dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8) + # Now we must convert the tensors to the mx format. + if is_fp8: + out_tensor = quant_tensor.to(mx_tensor_dtype) + else: + quant_tensor = quant_tensor.to(tl.uint32, bitcast=True) + signs = quant_tensor & 0x80000000 + exponents = (quant_tensor >> 23) & 0xFF + mantissas = (quant_tensor & 0x7FFFFF) + + # 0.25 <= x < 0.75 maps to 0.5, a denormal number + E8_BIAS = 127 + E2_BIAS = 1 + # Move implicit bit 1 at the beginning to mantissa for denormals + adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False) + mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas) + + # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0. 
+ exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS) + + # Combine sign, exponent, and mantissa, while saturating + # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right + e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7) + e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8) + + e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2]) + evens, odds = tl.split(e2m1_value) + out_tensor = evens | (odds << 4) + + return out_tensor, dequant_scale_exponent + +@triton.jit +def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr, + mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant, + src_ptr, stride_src_outer, stride_src_quant, + outer_dim, quant_dim, + BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr, + DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr): + + tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.") + tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32") + + # uint8 signifies two fp4 e2m1 values packed into a single byte + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5), + f"Invalid {mx_tensor_dtype=}. 
Must be uint8 or float8.") + + src_dtype: tl.constexpr = src_ptr.dtype.element_ty + tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8") + tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16) or (src_dtype == tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32") + is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8 + + outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1 + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE + BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR + + start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer + mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer + mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer + + offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + + mask_src_quant = start_src_quant + offs_src_quant < quant_dim + mask_n = start_out + offs_outer < outer_dim + full_mask_src = mask_src_quant & mask_n + + mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR) + full_mask_mxt = mask_mxt_quant & mask_n + + scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE) + full_scale_mask = scale_mask_k & mask_n + + src_tensor_offsets = offs_src_quant * stride_src_quant 
+ offs_outer * stride_src_outer + mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer + mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer + src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src) + + out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype, + DEQUANT_SCALE_ROUNDING_MODE) + + tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask) + tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt) + + +@triton.jit(repr=lambda _: "_dequantize_mxfp8") +def _dequantize_mxfp8_fn(input, mask, pid=None): + return _compute_quant_and_scale(input, mask, tl.float8e4nv) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/triton_kernels_numerics_details/_upcast_from_mxfp.py b/vllm/model_executor/layers/quantization/triton_kernels_numerics_details/_upcast_from_mxfp.py new file mode 100644 index 0000000..3ff2da8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/triton_kernels_numerics_details/_upcast_from_mxfp.py @@ -0,0 +1,136 @@ +import triton +import triton.language as tl + +from ._downcast_to_mxfp import MXFP_BLOCK_SIZE + + +# fmt: off +@triton.jit +def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer, + stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr, + outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr): + + tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx") + tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32") + # uint8 signifies two fp4 e2m1 values packed into a single byte + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + dst_dtype: tl.constexpr = out_ptr.dtype.element_ty + tl.static_assert(dst_dtype == 
tl.float16 or (dst_dtype == tl.bfloat16 or dst_dtype == tl.float32)) + tl.static_assert( + mx_tensor_dtype == tl.uint8 + or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype), + "mx_tensor_ptr must be uint8 or float8 or dst_dtype") + tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8") + + # Determine if we are dealing with fp8 types. + is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8 + is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5 + K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1 + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE + BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR + + # Compute starting indices for the quantized (packed) dimension and the outer dimension. + outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer + mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer + out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant + + # Compute offsets and masks. 
+ offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + + mask_outer = start_out + offs_outer < outer_dim + mask_out_quant = start_out_quant + offs_out_quant < quant_dim + full_mask_out = mask_out_quant & mask_outer + + mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR) + full_mask_src = mask_src_quant & mask_outer + + mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE) + full_scale_mask = mask_scale & mask_outer + + tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer + scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer + out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer + + # Load the packed tensor and scale. + tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src) + scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask) + + # Upcast the scale to the destination type. + if dst_dtype == tl.bfloat16: + # dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True) + dst_scale = (scale.to(tl.uint16) << 7).to(tl.uint16).to(tl.bfloat16, bitcast=True) + else: + dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + if dst_dtype == tl.float16: + dst_scale = dst_scale.to(tl.float16) + + # Now upcast the tensor. 
+ intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype + if is_fp8: + dst_tensor = tensor.to(intermediate_dtype) + if tensor.dtype == tl.float8e5: + from_e_bits: tl.constexpr = 5 + from_m_bits: tl.constexpr = 2 + to_e_bits: tl.constexpr = 8 if intermediate_dtype == tl.bfloat16 else 5 + to_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10 + + # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! + non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits + non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits + dst_tensor = tl.where( + (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src, + (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(intermediate_dtype, bitcast=True), + dst_tensor, + ) + else: + assert is_fp4 + dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15 + dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800 + dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10 + # e2m1 + em0 = tensor & 0x07 + em1 = tensor & 0x70 + x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12) + x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0) + x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + if intermediate_dtype == tl.bfloat16: + dst_tensor = tl.interleave(x0, x1).to(tl.uint16).to(tl.bfloat16, bitcast=True) + else: + dst_tensor = tl.interleave(x0, x1).to(tl.float16, bitcast=True) + # dst_tensor = 
dst_tensor.to(dst_dtype) + if dst_dtype == tl.bfloat16: + dst_tensor = dst_tensor.to(tl.bfloat16) + elif dst_dtype == tl.float16: + dst_tensor = dst_tensor.to(tl.float16) + else: + dst_tensor = dst_tensor.to(tl.float32) + + + # Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping. + dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE]) + dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1]) + scale = scale.reshape(dst_scale.shape) + + out_tensor = dst_tensor * dst_scale + # Correct any NaNs encoded via the scale. + out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor) + out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out) diff --git a/vllm/model_executor/layers/quantization/triton_kernels_numerics_details/mxfp.py b/vllm/model_executor/layers/quantization/triton_kernels_numerics_details/mxfp.py new file mode 100644 index 0000000..0fee19c --- /dev/null +++ b/vllm/model_executor/layers/quantization/triton_kernels_numerics_details/mxfp.py @@ -0,0 +1,303 @@ +# isort: off +# fmt: off +from enum import Enum +import triton +import torch +import torch.nn.functional as F +from ._upcast_from_mxfp import _upcast_from_mxfp +from ._downcast_to_mxfp import _downcast_to_mxfp, _dequantize_mxfp8_fn, MXFP_BLOCK_SIZE + +# ----------------------------------------------------------------------------- +# Dequantization / Quantization Utilities +# ----------------------------------------------------------------------------- + + +class DequantScaleRoundingMode(Enum): + ROUND_UP = 0 + ROUND_DOWN = 1 + + +def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int, + DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP): + """ + Convert the src weights to mx format. The src weight is quantized along the axis dimension. 
+ + If weight_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte. + Note that this means the k_dim of the tensor will be half of the logical k_dim. + + If weight_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8s are stored + in their respective formats. + """ + ndim = src_tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + # downcast + src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1) + is_fp4 = out_quant_type == torch.uint8 + is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2) + assert is_fp4 or is_fp8 + divisor = 2 if is_fp4 else 1 + L = src_tensor.shape[-1] + if is_fp4: + assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}" + out_shape = src_tensor.shape[:-1] + (L // divisor, ) + out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), ) + + out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type) + out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8) + + if src_tensor.numel() > 0: + kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1]) + kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1]) + kernel_scale = out_scale.view(-1, out_scale.shape[-1]) + + BLOCK_OUT_DIM = 128 + BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value + grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM) + grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM) + + _downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale, + *kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(), + *kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM, + DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8) + + out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1) + out_scale = out_scale.transpose(axis, src_tensor.ndim - 1) + return 
out_quant_tensor, out_scale + + +def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int): + """ + Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16. + + The function assumes that the tensors were quantized along the given axis. + It permutes the tensor so that the quantized axis is last, reshapes to 2D, + launches the Triton upcast kernel, and then unpermutes back to the original order. + """ + ndim = tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. " + f"Got {tensor.ndim=} and {scale.ndim=}") + # dtype checks + assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \ + f"Invalid tensor dtype {tensor.dtype=}" + assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}" + assert dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {dtype=}" + # upcast + logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1) + tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous() + scale = scale.transpose(axis, scale.ndim - 1).contiguous() + out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device) + reshaped_out = out.view(-1, out.shape[-1]) + reshaped_tensor = tensor.view(-1, tensor.shape[-1]) + reshaped_scale = scale.view(-1, scale.shape[-1]) + BLOCK_OUT_DIM = 128 + BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value + blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM) + blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM) + _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale, + *reshaped_scale.stride(), reshaped_tensor, + *reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM, + BLOCK_QUANT_DIM, num_warps=8) + out = out.transpose(axis, scale.ndim - 
1).contiguous() + return out + + +# ------------ + + +def right_shift_unsigned(x, shift): + # CUDA torch does not support bit ops on uint32, so we need to mask to get unsigned right shift + return (x >> shift) & ((1 << (32 - shift)) - 1) + + +def get_max_quant_val(dtype: torch.dtype): + d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0} + assert dtype in d + return d[dtype] + + +def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int, + DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP): + """ + Converts the src tensor to the output format specified by out_quant_type. + axis: The axis along which the tensors are contiguous and quantization is applied. + DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN. + + Returns: + out_quant_tensor: Quantized tensor in mx format. + • For mxfp8, the output has the same shape as src_tensor. + • For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8. + scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis. + Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32), + where L is the original length along that axis. + """ + # This should probably be packed into its own tiny class + ndim = src_tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + assert src_tensor.dtype in {torch.float32, torch.bfloat16, + torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}" + + axis = axis if axis >= 0 else axis + ndim + is_fp4 = out_quant_type == torch.uint8 + is_fp8 = "float8" in str(out_quant_type) + assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}" + + device = src_tensor.device + + # For mxfp4 conversion, we assume the contiguous axis length is even. + if is_fp4: + axis_shape = src_tensor.size(axis) + assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even." 
+ + # Permute the tensor so that the contiguous axis becomes the last dimension. + src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32) + axis_shape = src.shape[-1] + + # Pad the axis to be divisible by 32, in case it is not. + next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE + pad_amount = next_multiple - axis_shape + padded_src = F.pad(src, (0, pad_amount)) + valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount)) + padded_axis_shape = padded_src.size(-1) # now divisible by 32 + + # --- Compute per-group maximums for scale --- + # Set padded entries to -1 so they don’t affect the max. + abs_f = torch.abs(padded_src) + abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype)) + # Reshape the last dimension into groups of 32. + new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE) + abs_groups = abs_f.view(*new_shape) + # Compute maximum along the group dimension (of size 32). + max_val, _ = abs_groups.max(dim=-1, keepdim=True) + + # Choose a max quantization value depending on type. + max_quant_val = get_max_quant_val(out_quant_type) + dequant_scale = max_val / max_quant_val # shape: (..., padded_axis_shape//32, 1) + + # Convert to int to round the FP32 scale, prior to quantization! + ds_int = dequant_scale.view(torch.int32) + if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP: + ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000 + else: + ds_int_rounded = ds_int & 0x7F800000 + # Reinterpret back as float32. + dequant_scale_rounded = ds_int_rounded.view(torch.float32) + + # Compute the quantization scale. 
+ quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded) + + # Quantize the tensor + orig_padded_shape = padded_src.shape + padded_src_groups = padded_src.view(*new_shape) + quant_tensor = padded_src_groups * quant_scale + # Reshape back to the original shape and trim padding + quant_tensor = quant_tensor.view(orig_padded_shape) + quant_tensor = quant_tensor[..., :axis_shape] + + # Finally, convert the quantized tensor to the target format + if is_fp8: + # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior + quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val) + out_weight = quant_tensor.to(out_quant_type) + else: + assert is_fp4, f"Invalid output quantization type {out_quant_type}" + # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8. + # First, reinterpret the quantized tensor bits. + q_int = quant_tensor.contiguous().view(torch.int32) + # Extract sign, exponent, and mantissa. + signs = q_int & 0x80000000 + exponents = right_shift_unsigned(q_int, 23) & 0xFF + mantissas = q_int & 0x7FFFFF + + E8_BIAS = 127 + E2_BIAS = 1 + # Adjust mantissas for subnormals. + mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >> + (E8_BIAS - exponents - 1), mantissas) + exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS) + e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1) + e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device)) + e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8) # shape: (..., even_axis_shape) + + # Pack pairs of 4-bit values along the last dimension. 
+ e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2) + evens = e2m1_value[..., 0] + odds = e2m1_value[..., 1] + out_weight = evens | (odds << 4) # shape: (..., axis_shape//2) + + # --- Process and output the scale --- + dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8) # shape: (..., axis_shape//32, 1) + dq_scale = dq_scale.squeeze(-1) + out_weight = out_weight.transpose(axis, src_tensor.ndim - 1) + dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1) + return out_weight, dq_scale + + +def cvt_e2m1_to_fp32(input_tensor): + assert input_tensor.dtype == torch.uint8 + + input_tensor = input_tensor.to(torch.int32) + evens = input_tensor & 0xF + odds = (input_tensor >> 4) & 0xF + + vals = [0.0, 0.5, 1, 1.5, 2, 3, 4, 6] + outputs = torch.tensor(vals, dtype=torch.float32, device=input_tensor.device) + outputs = torch.cat([outputs, -outputs]) + + even_floats = outputs[evens] + odd_floats = outputs[odds] + output_tensor = torch.stack([even_floats, odd_floats], dim=-1) + output_tensor = output_tensor.view(*input_tensor.shape[:-1], -1) + return output_tensor + + +def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int): + """ + Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype. + axis: The axis along which dequantization is applied. + + Returns: + out_weight: Tensor in the target format. 
+ """ + + ndim = tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2 + assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}" + + # Permute the tensor and scale so that the quantization axis becomes the last dimension + axis = axis if axis >= 0 else axis + ndim + scale = scale.transpose(axis, scale.ndim - 1) + tensor = tensor.transpose(axis, tensor.ndim - 1) + + dq_scale = (scale.to(torch.int32) << 23).view(torch.float32) # Shift to the exponent and bitcast to fp32 + if tensor.dtype == torch.uint8: + fp32_tensor = cvt_e2m1_to_fp32(tensor) + else: + fp32_tensor = tensor.to(torch.float32) + + logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1) + axis_shape = fp32_tensor.size(-1) + padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE + pad_size = padded_axis_shape - axis_shape + padded_tensor = F.pad(fp32_tensor, (0, pad_size)) + + new_axis_shape = padded_tensor.shape[-1] + new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE) + padded_tensor = padded_tensor.view(*new_shape) + dq_scale_padded = dq_scale.unsqueeze(-1) # shape: [..., ceil(axis_shape/32), 1] + out_padded = padded_tensor * dq_scale_padded + + # Flatten back and remove the padded tail + out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape) + out_tensor = out_padded[..., :axis_shape] + + out_tensor = out_tensor.to(target_dtype).contiguous() + out_tensor = out_tensor.transpose(axis, tensor.ndim - 1) + + return out_tensor + + +dequantize_mxfp8_fn = _dequantize_mxfp8_fn \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 7540a15..02057b4 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ 
b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -261,6 +261,13 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor: + origin_shape = s.shape + _, scale_perm_single = get_scale_perms() + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + return s.reshape(*origin_shape).contiguous() + + def marlin_moe_permute_scales( s: torch.Tensor, size_k: int, @@ -410,6 +417,7 @@ def apply_gptq_marlin_linear( output = ops.gptq_marlin_gemm(reshaped_x, None, weight, + bias, weight_scale, None, weight_zp, @@ -425,9 +433,6 @@ def apply_gptq_marlin_linear( use_fp32_reduce=use_fp32_reduce, is_zp_float=False) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -456,6 +461,7 @@ def apply_awq_marlin_linear( output = ops.gptq_marlin_gemm(reshaped_x, None, weight, + bias, weight_scale, None, weight_zp, @@ -470,7 +476,4 @@ def apply_awq_marlin_linear( use_fp32_reduce=use_fp32_reduce, is_zp_float=False) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index ca10db6..94ffdcd 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -8,8 +8,8 @@ import torch import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, - should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias, + marlin_permute_scales, should_use_atomic_add_reduce) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -22,7 +22,7 @@ def 
is_fp4_marlin_supported(): return current_platform.has_device_capability(80) -def fp4_marlin_process_scales(marlin_scales): +def nvfp4_marlin_process_scales(marlin_scales): if not (marlin_scales >= 0).all(): logger.warning_once( "NVFP4 Marlin assumes the scales to be >=0, but has encountered " @@ -56,7 +56,20 @@ def fp4_marlin_process_scales(marlin_scales): return marlin_scales -def fp4_marlin_process_global_scale(global_scale): +def mxfp4_marlin_process_scales(marlin_scales): + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) + return marlin_scales + + +def nvfp4_marlin_process_global_scale(global_scale): assert global_scale.dtype in [torch.half, torch.bfloat16] fp4_exponent = 2 if global_scale.dtype == torch.half: @@ -73,7 +86,7 @@ def apply_fp4_marlin_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, - weight_scale_2: torch.Tensor, + weight_scale_2: Optional[torch.Tensor], workspace: torch.Tensor, size_n: int, size_k: int, @@ -94,6 +107,7 @@ def apply_fp4_marlin_linear( output = ops.gptq_marlin_gemm(a=reshaped_x, c=None, b_q_weight=weight, + b_bias=bias, b_scales=weight_scale, global_scale=weight_scale_2, b_zeros=None, @@ -107,9 +121,6 @@ def apply_fp4_marlin_linear( use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -120,6 +131,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "be used leveraging the Marlin kernel. 
This may degrade " "performance for compute-heavy workloads.") + is_nvfp4 = hasattr(layer, "weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + part_size_n = layer.output_size_per_partition part_size_k = layer.input_size_per_partition param_dtype = layer.params_dtype @@ -145,18 +159,35 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Permute scales - weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = layer.weight_scale.T.contiguous() + + if not is_nvfp4: + weight_scale = weight_scale.view(torch.float8_e8m0fnu) + + weight_scale = weight_scale.to(param_dtype) weight_scale = marlin_permute_scales(s=weight_scale, size_k=part_size_k, size_n=part_size_n, - group_size=16) - weight_scale = fp4_marlin_process_scales(weight_scale) - layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + group_size=group_size) - weight_scale_2 = layer.weight_scale_2.to(param_dtype) - weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) - layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, - requires_grad=False) + if is_nvfp4: + weight_scale = nvfp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, + requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + else: + weight_scale = mxfp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, + requires_grad=False) + + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n, ) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) return @@ -168,6 +199,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "be used leveraging the Marlin kernel. 
This may degrade " "performance for compute-heavy workloads.") + is_nvfp4 = hasattr(layer, "w13_weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + e = layer.num_experts k = layer.hidden_size n = layer.intermediate_size_per_partition @@ -208,8 +242,13 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Permute scales for name in ["w13", "w2"]: - scales = getattr(layer, name + "_weight_scale").to(param_dtype) - global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + scales = getattr(layer, name + "_weight_scale") + if not is_nvfp4: + scales = scales.view(torch.float8_e8m0fnu) + scales = scales.to(param_dtype) + if is_nvfp4: + global_scale = getattr(layer, + name + "_weight_scale_2").to(param_dtype) tensor_list = [] if "w13" in name: @@ -218,23 +257,47 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: size_n, size_k = k, n for i in range(e): - marlin_scales = marlin_permute_scales(s=scales[i].T, + scale = scales[i].T + + marlin_scales = marlin_permute_scales(s=scale, size_k=size_k, size_n=size_n, - group_size=16) - marlin_scales = fp4_marlin_process_scales(marlin_scales) + group_size=group_size) + if is_nvfp4: + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) + else: + marlin_scales = mxfp4_marlin_process_scales(marlin_scales) tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) scales = torch.nn.Parameter(scales, requires_grad=False) setattr(layer, name + "_weight_scale", scales) - global_scale = fp4_marlin_process_global_scale(global_scale) - global_scale = torch.nn.Parameter(global_scale, requires_grad=False) - setattr(layer, name + "_weight_scale_2", global_scale) + if is_nvfp4: + global_scale = nvfp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, + requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + # BIAS + # Permute bias + for name in ["w13_bias", 
"w2_bias"]: + if not hasattr(layer, name): + continue + bias = getattr(layer, name).to(param_dtype) + + tensor_list = [] + for i in range(e): + expert_bias = bias[i] + + tensor_list.append(marlin_permute_bias(expert_bias)) + + bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + bias = torch.nn.Parameter(bias, requires_grad=False) + setattr(layer, name, bias) -def rand_marlin_weight_fp4_like(weight, group_size): +def rand_marlin_weight_nvfp4_like(weight, group_size): assert group_size > 0 size_n, size_k = weight.shape device = weight.device @@ -276,8 +339,58 @@ def rand_marlin_weight_fp4_like(weight, group_size): size_k=size_k, size_n=size_n, group_size=group_size) - marlin_scales = fp4_marlin_process_scales(marlin_scales) + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) - global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = nvfp4_marlin_process_global_scale(global_scale) return weight_ref.T, marlin_qweight, marlin_scales, global_scale + + +def rand_marlin_weight_mxfp4_like(weight, group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = torch.randint(100, + 125, (size_n, size_k // group_size), + dtype=torch.uint8, + device=weight.device) + scales = scales.view(torch.float8_e8m0fnu) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, 
size_k) + weight_ref = weight_ref * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + + marlin_scales = mxfp4_marlin_process_scales(marlin_scales) + + return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 9d4a188..deeb69b 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -1,45 +1,133 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable, Optional import torch +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer + +logger = init_logger(__name__) + OCP_MX_BLOCK_SIZE = 32 -def per_token_group_quant_mxfp4(x: torch.Tensor, - block_k: int, - scale_calculation_mode: str = "even" - ) -> tuple[torch.Tensor, torch.Tensor]: +def _swizzle_mxfp4(quant_tensor, scale, num_warps): + """ weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel + """ + import triton_kernels.matmul_ogs_details.opt_flags as opt_flags + from triton_kernels.numerics import InFlexData + from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor + from triton_kernels.tensor_details import layout + from triton_kernels.tensor_details.layout import StridedLayout + if (current_platform.is_cuda() + and current_platform.is_device_capability(90) + and not is_torch_equal_or_newer("2.8.1")): + logger.warning_once( + 
"Mxfp4 on hopper is running on torch < 2.8.1, " + "this cause swizling to be disabled, which may " + "cause performance degradation. Please upgrade to torch nightly") + value_layout, value_layout_opts = StridedLayout, dict() + scale_layout, scale_layout_opts = StridedLayout, dict() + else: + value_layout, value_layout_opts = \ + layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) + scale_layout, scale_layout_opts = ( + layout.make_default_matmul_mxfp4_w_scale_layout( + mx_axis=1, num_warps=num_warps)) + if current_platform.is_cuda() and \ + current_platform.is_device_capability(100): + constraints = { + "is_persistent": True, + "epilogue_subtile": 1, + } + opt_flags.update_opt_flags_constraints(constraints) + # transpose the tensor so that the quantization axis is on dim1 + quant_tensor = quant_tensor.transpose(-2, -1) + scale = scale.transpose(-2, -1) + quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4), + value_layout, **value_layout_opts) + scale = convert_layout(wrap_torch_tensor(scale), scale_layout, + **scale_layout_opts) + return quant_tensor, InFlexData(), scale + + +def _can_support_mxfp4(use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + scoring_func: str = "softmax", + activation: str = "swiglu_oai", + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None): + return not (use_grouped_topk or topk_group or num_expert_group + or expert_map or custom_routing_function + or e_score_correction_bias or apply_router_weight_on_input + or scoring_func != "softmax" or activation != "swiglu_oai" + or expert_load_view or logical_to_physical_map + or logical_replica_count) + + +def 
_dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor, + float_dtype: torch.dtype) -> torch.Tensor: try: - from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( - fake_quantize_fp4_fp6_per_group_with_scale) - from quark.torch.quantization.utils import (even_round, - reshape_to_blocks) + from quark.torch.kernel import mx except ImportError as err: raise ImportError("The package `amd-quark` is required to use " "MX-FP4 models. Please install it with `pip install " "amd-quark`.") from err - axis = -1 - block_x = reshape_to_blocks(x, block_k, axis) - amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) - amax = amax.squeeze(-1) + return mx.dq_mxfp4(x, scale, float_dtype) - # TODO: there are other rounding strategies supported in quark and in the - # config.json that we do not check for here! - if scale_calculation_mode != "even": - raise NotImplementedError( - f"Scale calculation mode {scale_calculation_mode} is not yet " - "supported in MX-FP4 quantization") - scale = even_round(amax, "fp4") - # Apply dequantize(quantize(x)). - x = fake_quantize_fp4_fp6_per_group_with_scale( - x, - scale.to(x.device), - axis=axis, - group_size=block_k, - quant_dtype="fp4", +def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor, + float_dtype: torch.dtype) -> torch.Tensor: + return torch.empty((*x.shape[:-1], x.shape[-1] * 2), + dtype=float_dtype, + device=x.device) + + +def _quant_dequant_mxfp4(x: torch.Tensor, + scale_calculation_mode: str = "even") -> torch.Tensor: + try: + from quark.torch.kernel import mx + except ImportError as err: + raise ImportError("The package `amd-quark` is required to use " + "MX-FP4 models. 
Please install it with `pip install " + "amd-quark`.") from err + + return mx.qdq_mxfp4(x, scale_calculation_mode) + + +def _quant_dequant_mxfp4_fake(x: torch.Tensor, + scale_calculation_mode: str = "even" + ) -> torch.Tensor: + return torch.empty_like(x) + + +try: + direct_register_custom_op( + op_name="dequant_mxfp4", + op_func=_dequant_mxfp4, + mutates_args=[], + fake_impl=_dequant_mxfp4_fake, ) + dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4 +except AttributeError as error: + raise error - return x, scale +try: + direct_register_custom_op( + op_name="quant_dequant_mxfp4", + op_func=_quant_dequant_mxfp4, + mutates_args=[], + fake_impl=_quant_dequant_mxfp4_fake, + ) + quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4 +except AttributeError as error: + raise error diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d6b9677..428e9e9 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -3,22 +3,41 @@ """This file is used for /tests and /benchmarks""" from collections.abc import Mapping from types import MappingProxyType -from typing import Optional +from typing import ClassVar, NamedTuple, Optional import numpy import torch +from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 from vllm.model_executor.layers.quantization.qqq import ( MARLIN_QQQ_SUPPORTED_NUM_BITS) +from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# Use proxy as NamedTuple direct subclasses cannot have static members +class _GroupShape(NamedTuple): + row: int + col: int + + +class GroupShape(_GroupShape): + """ + This class describes the quantization group shape. + It includes static members for common shapes (per-tensor, per-token). 
+ """ + + # Aliases for common quantization group shapes + PER_TENSOR: ClassVar['GroupShape'] + PER_TOKEN: ClassVar['GroupShape'] + + +GroupShape.PER_TENSOR = GroupShape(-1, -1) +GroupShape.PER_TOKEN = GroupShape(1, -1) # Normalize the group_shape to the full extent for any dims that are -1 -def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int, - int]): +def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], group_shape[1] if group_shape[1] > 0 else x.shape[-1]) @@ -58,7 +77,7 @@ def group_broadcast(t, shape): # (i.e. per-token-per-group) def scaled_quantize( x: torch.Tensor, - group_shape: tuple[int, int], + group_shape: GroupShape, quant_dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: group_shape = _normalize_quant_group_shape(x, group_shape) @@ -99,7 +118,7 @@ def scaled_quantize( def scaled_dequantize( x_q: torch.Tensor, x_s: torch.Tensor, - group_shape: Optional[tuple[int, int]] = None, + group_shape: Optional[GroupShape] = None, out_dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor]: if group_shape is not None: @@ -332,6 +351,10 @@ def quantize_weights(w: torch.Tensor, ) +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, group_size: int, @@ -571,3 +594,56 @@ def awq_pack( q_w = q_w.reshape((-1, size_n)).contiguous() return pack_cols(q_w, num_bits, size_k, size_n) + + +def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor: + """ + Pad and block-interleave the FP4 block-scales so that they match the data + layout expected by the CUTLASS / FlashInfer kernels. + + Parameters + ---------- + scale: torch.Tensor + + Returns + ------- + torch.Tensor + The swizzled tensor with the same logical shape as *scale*. 
+ """ + assert scale.dtype == torch.float8_e4m3fn, ( + "swizzle_blockscale expects the input tensor to be in " + "torch.float8_e4m3fn format.") + + scale_ndim = scale.ndim + if scale_ndim == 2: + scale = scale.unsqueeze(0) # (1, M, K) + assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales." + + B, M, K = scale.shape + + def _round_up(x: int, m: int) -> int: + return (x + m - 1) // m * m + + M_padded = _round_up(M, 128) + K_padded = _round_up(K, 4) + + padded = torch.zeros((B, M_padded, K_padded), + dtype=scale.dtype, + device=scale.device) + padded[:B, :M, :K] = scale + + # Reshape / permute to the layout required by the kernel. + padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4) + swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda() + + if scale_ndim == 2: + return swizzled.reshape(M, K) + return swizzled.reshape(B, M, K) + + +def cutlass_fp4_supported() -> bool: + if not current_platform.is_cuda(): + return False + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + return cutlass_scaled_mm_supports_fp4(capability) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 7d4293a..d8a468d 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -8,7 +8,6 @@ import torch.distributed as dist from torch import nn from transformers import GptOssConfig -from vllm import envs from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig @@ -28,7 +27,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import cdiv -from .utils import extract_layer_index, maybe_prefix +from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, + maybe_prefix) class 
OAIAttention(nn.Module): @@ -70,12 +70,9 @@ class OAIAttention(nn.Module): tp_size = get_tensor_model_parallel_world_size() - # attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION - # else torch.bfloat16) - attention_sink_dtype = torch.bfloat16 self.sinks = torch.nn.Parameter( torch.empty(config.num_attention_heads // tp_size, - dtype=attention_sink_dtype, + dtype=torch.bfloat16, requires_grad=False)) self.norm = RMSNorm(config.hidden_size, eps=1e-5) @@ -207,6 +204,7 @@ class GptOssModel(nn.Module): super().__init__() self.config = vllm_config.model_config.hf_config self.quant_config = vllm_config.quant_config + self.parallel_config = vllm_config.parallel_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( self.config.vocab_size, @@ -229,8 +227,364 @@ class GptOssModel(nn.Module): x = self.norm(x) return x + def _load_weights_mxfp4( + self, + ep_rank_end: int, + ep_rank_start: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + mxfp4_block = 32 + use_ep = self.parallel_config.enable_expert_parallel + num_experts = self.config.num_local_experts + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.intermediate_size + intermediate_size_block = intermediate_size // mxfp4_block + per_rank_intermediate_size_block = cdiv(intermediate_size_block, + tp_size) + per_rank_intermediate_size = (per_rank_intermediate_size_block * + mxfp4_block) + + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + for name, weight in weights: + # FIXME(woosuk): Remove this after testing. 
+ weight = weight.cuda() + + if ".w13_weight_scale" in name: + # Handle MLP gate and up projection weights scale + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_weight_scale" in name: + # Handle MLP down projection weights + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., tp_rank_start // + mxfp4_block:tp_rank_end // + mxfp4_block] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w13_weight" in name: + # Handle MLP gate and up projection weights + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view(num_experts, 2 * intermediate_size, + -1).contiguous() + + # Extract gate and up projection parts + # since the weight is shuffled, we can slice directly + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] 
+ + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_weight" in name: + # Handle MLP down projection weights + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view(num_experts, -1, + intermediate_size // 2).contiguous() + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., + tp_rank_start // 2:tp_rank_end // 2] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w13_bias" in name: + # Handle MLP gate and up projection biases + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_bias" in name: + # Handle MLP down projection bias + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] 
+ else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + weight_loader(param, + weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, weight) + else: + weight_loader(param, weight, shard_id) + break + else: + # Handle all other weights with potential renaming + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(name) + return loaded_params + + def _load_weights_other( + self, + ep_rank_start: int, + ep_rank_end: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + use_ep = self.parallel_config.enable_expert_parallel + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.intermediate_size + per_rank_intermediate_size = cdiv(intermediate_size, tp_size) + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + for name, weight in weights: + if ".w13_weight" in name: + # Handle MLP gate and up 
projection weights + # Extract gate and up projection parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, :, + 2 * tp_rank_start:2 * tp_rank_end] + + narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + param = params_dict[name] + + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w2_weight" in name: + # Handle MLP down projection weights + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] + narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + param = params_dict[name] + + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w13_bias" in name: + # Handle MLP gate and up projection biases + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[name] + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w2_bias" in name: + # Handle MLP down projection bias + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] 
+ else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + param = params_dict[name] + param.copy_(weight) + loaded_params.add(name) + continue + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, weight) + else: + weight_loader(param, weight, shard_id) + break + else: + # Handle all other weights with potential renaming + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(name) + return loaded_params + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv", ".q_proj", "q"), + (".qkv", ".k_proj", "k"), + (".qkv", ".v_proj", "v"), + ] + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + # Attention heads per rank + heads_per_rank = self.config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + ep_size = get_ep_group().world_size + ep_rank = get_ep_group().rank + num_experts = self.config.num_local_experts + experts_per_rank = num_experts // ep_size + ep_rank_start = ep_rank * experts_per_rank + ep_rank_end = (ep_rank + 1) * experts_per_rank + + quant_method = (self.config.quantization_config['quant_method'] if + hasattr(self.config, "quantization_config") else None) + if quant_method == "mxfp4": + return 
self._load_weights_mxfp4(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) + else: + return self._load_weights_other(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) + class GptOssForCausalLM(nn.Module): + packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".self_attn.": ".attn.", + ".post_attention_layernorm.": ".mlp.norm.", + }, + orig_to_new_suffix={ + ".embed_tokens.weight": ".embedding.weight", + ".input_layernorm.weight": ".attn.norm.weight", + ".post_attention_layernorm.weight": ".mlp.norm.weight", + + # MoE MXFP4 weights + ".gate_up_proj_blocks": ".w13_weight", + ".down_proj_blocks": ".w2_weight", + ".gate_up_proj_scales": ".w13_weight_scale", + ".down_proj_scales": ".w2_weight_scale", + + # MoE other weights + ".gate_up_proj": ".w13_weight", + ".down_proj": ".w2_weight", + + # MoE Bias + ".gate_up_proj_bias": ".w13_bias", + ".down_proj_bias": ".w2_bias", + }, + ) def __init__( self, @@ -239,16 +593,17 @@ class GptOssForCausalLM(nn.Module): ): super().__init__() self.vllm_config = vllm_config - self.model_config = vllm_config.model_config.hf_config + self.config = vllm_config.model_config.hf_config + self.model = GptOssModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"), ) self.lm_head = ParallelLMHead( - self.model_config.vocab_size, - self.model_config.hidden_size, + self.config.vocab_size, + self.config.hidden_size, ) - self.logits_processor = LogitsProcessor(self.model_config.vocab_size) + self.logits_processor = LogitsProcessor(self.config.vocab_size) def forward(self, input_ids: torch.Tensor, @@ -265,354 +620,11 @@ class GptOssForCausalLM(nn.Module): sampling_metadata) return logits - def _load_weights_mxfp4( - self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - rename_mapping = { - "self_attn": "attn", - "input_layernorm.weight": "attn.norm.weight", - 
"post_attention_layernorm.weight": "mlp.norm.weight", - "embed_tokens": "embedding", - } - - def maybe_rename(name: str) -> str: - for remap_name, new_name in rename_mapping.items(): - if remap_name in name: - return name.replace(remap_name, new_name) - return name - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - mxfp4_block = 32 - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - intermediate_size = self.model_config.intermediate_size - intermediate_size_block = intermediate_size // mxfp4_block - per_rank_intermediate_size_block = cdiv(intermediate_size_block, - tp_size) - per_rank_intermediate_size = (per_rank_intermediate_size_block * - mxfp4_block) - - # Calculate common slicing bounds for current rank - tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, - intermediate_size) - - # Attention heads per rank - heads_per_rank = self.model_config.num_attention_heads // tp_size - head_start = tp_rank * heads_per_rank - - use_ep = self.vllm_config.parallel_config.enable_expert_parallel - ep_size = get_ep_group().world_size - ep_rank = get_ep_group().rank - num_experts = self.model_config.num_local_experts - experts_per_rank = num_experts // ep_size - ep_rank_start = ep_rank * experts_per_rank - ep_rank_end = (ep_rank + 1) * experts_per_rank - - for name, weight in weights: - # FIXME(woosuk): Remove this after testing. 
- weight = weight.cuda() - - if "gate_up_proj_blocks" in name: - # Handle MLP gate and up projection weights - new_name = name.replace("gate_up_proj_blocks", "w13_weight") - - # flat weight from (E, 2 * N, block_size, entry_per_block) - # to (E, 2 * N, -1), shouldn't trigger copy for contiguous - weight = weight.view(num_experts, 2 * intermediate_size, - -1).contiguous() - - # Extract gate and up projection parts - # since the weight is shuffled, we can slice directly - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end, - ...] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_blocks" in name: - # Handle MLP down projection weights - new_name = name.replace("down_proj_blocks", "w2_weight") - # same flatten here, but since 2 mx4 value are packed in 1 - # uint8, divide by 2 - weight = weight.view(num_experts, -1, - intermediate_size // 2).contiguous() - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[..., - tp_rank_start // 2:tp_rank_end // 2] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "gate_up_proj_scales" in name: - # Handle MLP gate and up projection weights scale - new_name = name.replace("gate_up_proj_scales", - "w13_weight_scale") - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end, - ...] 
- - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_scales" in name: - # Handle MLP down projection weights - new_name = name.replace("down_proj_scales", "w2_weight_scale") - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[..., tp_rank_start // - mxfp4_block:tp_rank_end // - mxfp4_block] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - elif "gate_up_proj_bias" in name: - # Handle MLP gate and up projection biases - new_name = name.replace("gate_up_proj_bias", "w13_bias") - - # Extract gate and up projection bias parts - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_bias" in name: - # Handle MLP down projection bias - new_name = name.replace("down_proj_bias", "w2_bias") - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - if use_ep: - weight = weight[ep_rank_start:ep_rank_end, ...] 
- else: - # (only load on rank 0 to avoid duplication) - if tp_rank != 0: - weight.zero_() - weight_loader(param, - weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - elif "sinks" in name: - # Handle attention sinks (distributed across ranks) - name = name.replace("self_attn", "attn") - param = params_dict[name] - narrow_weight = weight.narrow(0, head_start, heads_per_rank) - param.data.copy_(narrow_weight) - loaded_params.add(name) - elif "q_proj" in name or "k_proj" in name or "v_proj" in name: - shard_id = ("q" if "q_proj" in name else - "k" if "k_proj" in name else "v") - name = name.replace("self_attn", "attn") - param_name = name.replace(f"{shard_id}_proj", "qkv") - param = params_dict[param_name] - weight_loader = param.weight_loader - weight_loader(param, weight, loaded_shard_id=shard_id) - loaded_params.add(param_name) - else: - # Handle all other weights with potential renaming - renamed_name = maybe_rename(name) - if renamed_name not in params_dict: - continue - param = params_dict[renamed_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, weight) - loaded_params.add(renamed_name) - - return loaded_params - - def _load_weights_other( - self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - rename_mapping = { - "self_attn": "attn", - "input_layernorm.weight": "attn.norm.weight", - "post_attention_layernorm.weight": "mlp.norm.weight", - "embed_tokens": "embedding", - } - - def maybe_rename(name: str) -> str: - for remap_name, new_name in rename_mapping.items(): - if remap_name in name: - return name.replace(remap_name, new_name) - return name - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - intermediate_size = self.model_config.intermediate_size - - per_rank_intermediate_size = cdiv(intermediate_size, tp_size) - # 
Calculate common slicing bounds for current rank - tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, - intermediate_size) - - # Attention heads per rank - heads_per_rank = self.model_config.num_attention_heads // tp_size - head_start = tp_rank * heads_per_rank - - use_ep = self.vllm_config.parallel_config.enable_expert_parallel - ep_size = get_ep_group().world_size - ep_rank = get_ep_group().rank - num_experts = self.model_config.num_local_experts - experts_per_rank = num_experts // ep_size - ep_rank_start = ep_rank * experts_per_rank - ep_rank_end = (ep_rank + 1) * experts_per_rank - - for name, weight in weights: - if ".experts.gate_up_proj" in name and "bias" not in name: - # Handle MLP gate and up projection weights - new_name = name.replace(".experts.gate_up_proj", - ".experts.w13_weight") - - # Extract gate and up projection parts - # since the weight is shuffled, we can slice directly - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, :, - 2 * tp_rank_start:2 * tp_rank_end] - - narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() - param = params_dict[new_name] - - param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif ".experts.down_proj" in name and "bias" not in name: - # Handle MLP down projection weights - new_name = name.replace(".experts.down_proj", - ".experts.w2_weight") - - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
- else: - narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] - narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() - param = params_dict[new_name] - - param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif "gate_up_proj_bias" in name: - # Handle MLP gate and up projection biases - new_name = name.replace("gate_up_proj_bias", "w13_bias") - - # Extract gate and up projection bias parts - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end] - - param = params_dict[new_name] - - param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif "down_proj_bias" in name: - # Handle MLP down projection bias - new_name = name.replace("down_proj_bias", "w2_bias") - - if use_ep: - weight = weight[ep_rank_start:ep_rank_end, ...] - else: - # (only load on rank 0 to avoid duplication) - if tp_rank != 0: - weight.zero_() - param = params_dict[new_name] - param.copy_(weight) - loaded_params.add(new_name) - elif "sinks" in name: - # Handle attention sinks (distributed across ranks) - name = name.replace("self_attn", "attn") - param = params_dict[name] - narrow_weight = weight.narrow(0, head_start, heads_per_rank) - param.data.copy_(narrow_weight) - loaded_params.add(name) - elif "q_proj" in name or "k_proj" in name or "v_proj" in name: - shard_id = ("q" if "q_proj" in name else - "k" if "k_proj" in name else "v") - name = name.replace("self_attn", "attn") - param_name = name.replace(f"{shard_id}_proj", "qkv") - param = params_dict[param_name] - weight_loader = param.weight_loader - weight_loader(param, weight, loaded_shard_id=shard_id) - loaded_params.add(param_name) - else: - # Handle all other weights with potential renaming - - renamed_name = maybe_rename(name) - if renamed_name not in params_dict: - continue - param = params_dict[renamed_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, weight) - 
loaded_params.add(renamed_name) - - return loaded_params - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - quant_method = (self.model_config.quantization_config['quant_method'] - if hasattr(self.model_config, "quantization_config") - else None) - if quant_method == "mxfp4": - return self._load_weights_mxfp4(weights) - else: - return self._load_weights_other(weights) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) \ No newline at end of file diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index dee2bb7..129ee9b 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -179,12 +179,14 @@ class CudaPlatformBase(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, - use_mla) -> str: + kv_cache_dtype, block_size, use_v1, use_mla, + has_sink) -> str: if use_mla: - # TODO(lucas): refactor to be more concise + # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here - if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1: + if selected_backend == _Backend.CUTLASS_MLA or ( + cls.is_device_capability(100) and selected_backend is None + and block_size == 128): if use_v1: logger.info_once("Using Cutlass MLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." @@ -223,37 +225,101 @@ class CudaPlatformBase(Platform): return ("vllm.attention.backends." 
"flashmla.FlashMLABackend") if use_v1: + FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 + FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 + TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 + FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 + XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 + if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") - return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" - if selected_backend == _Backend.FLEX_ATTENTION: - logger.info("Using FlexAttenion backend on V1 engine.") - return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 - if selected_backend == _Backend.TRITON_ATTN_VLLM_V1: + if cls.has_device_capability(100): + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + set_kv_cache_layout("HND") + return FLASHINFER_V1 + elif selected_backend == _Backend.FLEX_ATTENTION: + logger.info_once("Using FlexAttention backend on V1 engine.") + return FLEX_ATTENTION_V1 + elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1: logger.info_once("Using Triton backend on V1 engine.") - return ("vllm.v1.attention.backends." 
- "triton_attn.TritonAttentionBackend") + return TRITON_ATTN_VLLM_V1 + elif selected_backend == _Backend.FLASH_ATTN: + logger.info_once("Using Flash Attention backend on V1 engine.") + return FLASH_ATTN_V1 + # elif selected_backend == _Backend.TREE_ATTN: + # logger.info_once("Using Tree Attention backend on V1 engine.") + # return TREE_ATTN_V1 + # elif selected_backend == _Backend.XFORMERS_VLLM_V1: + # logger.info_once("Using XFormers backend on V1 engine.") + # return XFORMERS_V1 + + from vllm.attention.selector import is_attn_backend_supported + + # Default backends for V1 engine + # Prefer FlashInfer for Blackwell GPUs if installed if cls.is_device_capability(100): - # Prefer FlashInfer for V1 on Blackwell GPUs if installed - try: - import flashinfer # noqa: F401 - logger.info_once( - "Using FlashInfer backend on V1 engine by default for " - "Blackwell (SM 10.0) GPUs.") - return ("vllm.v1.attention.backends." - "flashinfer.FlashInferBackend") - except ImportError: + if is_default_backend_supported := is_attn_backend_supported( + FLASHINFER_V1, head_size, dtype): + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + logger.info_once( + "Using FlashInfer backend with HND KV cache layout on " + "V1 engine by default for Blackwell (SM 10.0) GPUs.") + set_kv_cache_layout("HND") + + return FLASHINFER_V1 + + if not is_default_backend_supported.can_import: + logger.warning_once( "FlashInfer failed to import for V1 engine on " "Blackwell (SM 10.0) GPUs; it is recommended to " "install FlashInfer for better performance.") - pass + + # FlashAttention is the default for SM 8.0+ GPUs if cls.has_device_capability(80): - logger.info_once("Using Flash Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." 
- "flash_attn.FlashAttentionBackend") + if has_sink and not cls.is_device_capability(90): + logger.info_once("Using Triton backend on V1 engine.") + return TRITON_ATTN_VLLM_V1 + if is_default_backend_supported := is_attn_backend_supported( + FLASH_ATTN_V1, head_size, dtype, + allow_import_error=False): + logger.info_once("Using Flash Attention backend on " + "V1 engine.") + return FLASH_ATTN_V1 + + # FlexAttention is the default for older GPUs + else: + logger.info_once("Using FlexAttention backend on V1 engine.") + return FLEX_ATTENTION_V1 + + assert not is_default_backend_supported + + use_flex_attention_reason = {} + if not is_default_backend_supported.head_size: + use_flex_attention_reason["head_size"] = head_size + if not is_default_backend_supported.dtype: + use_flex_attention_reason["dtype"] = dtype + + logger.info_once( + "Using FlexAttention backend for %s on V1 engine.", + ", ".join(f"{k}={v}" + for k, v in use_flex_attention_reason.items()), + ) + return FLEX_ATTENTION_V1 + + # Backends for V0 engine if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") + if cls.has_device_capability(100): + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + logger.info_once( + "Using HND KV cache layout on V1 engine by default for " + "Blackwell (SM 10.0) GPUs.") + set_kv_cache_layout("HND") return "vllm.attention.backends.flashinfer.FlashInferBackend" elif selected_backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") @@ -262,6 +328,10 @@ class CudaPlatformBase(Platform): logger.info("Using DualChunkFlashAttention backend.") return ("vllm.attention.backends.dual_chunk_flash_attn." "DualChunkFlashAttentionBackend") + elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN: + logger.info("Using DifferentialFlashAttention backend.") + return ("vllm.attention.backends.differential_flash_attn." 
+ "DifferentialFlashAttentionBackend") elif selected_backend == _Backend.FLASH_ATTN: pass elif selected_backend: @@ -291,7 +361,7 @@ class CudaPlatformBase(Platform): # installed. if target_backend == _Backend.FLASH_ATTN: try: - import flash_attn # noqa: F401 + import vllm.vllm_flash_attn # noqa: F401 from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend, flash_attn_supports_fp8) diff --git a/vllm/utils.py b/vllm/utils.py index 12590b7..8126ae3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2908,3 +2908,37 @@ def is_torch_equal_or_newer(target: str) -> bool: except Exception: # Fallback to PKG-INFO to load the package info, needed by the doc gen. return Version(importlib.metadata.version('torch')) >= Version(target) + + +@cache +def _has_module(module_name: str) -> bool: + """Return True if *module_name* can be found in the current environment. + + The result is cached so that subsequent queries for the same module incur + no additional overhead. + """ + return importlib.util.find_spec(module_name) is not None + + +def has_pplx() -> bool: + """Whether the optional `pplx_kernels` package is available.""" + + return _has_module("pplx_kernels") + + +def has_deep_ep() -> bool: + """Whether the optional `deep_ep` package is available.""" + + return _has_module("deep_ep") + + +def has_deep_gemm() -> bool: + """Whether the optional `deep_gemm` package is available.""" + + return _has_module("deep_gemm") + + +def has_triton_kernels() -> bool: + """Whether the optional `triton_kernels` package is available.""" + + return _has_module("triton_kernels") diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 22d361c..5c6c57a 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -44,10 +44,25 @@ class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - @staticmethod - def get_supported_head_sizes() -> list[int]: + 
@classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256] + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. " + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + @staticmethod def get_name() -> str: return "FLASH_ATTN_VLLM_V1" @@ -657,15 +672,12 @@ class FlashAttentionImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + sinks: Optional[torch.Tensor] = None, use_irope: bool = False, ) -> None: - if blocksparse_params is not None: - raise ValueError( - "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -675,6 +687,8 @@ class FlashAttentionImpl(AttentionImpl): self.alibi_slopes = alibi_slopes if sliding_window is None: self.sliding_window = (-1, -1) + elif attn_type == AttentionType.ENCODER_ONLY: + self.sliding_window = (sliding_window - 1, sliding_window - 1) else: self.sliding_window = (sliding_window - 1, 0) self.kv_cache_dtype = kv_cache_dtype @@ -684,28 +698,33 @@ class FlashAttentionImpl(AttentionImpl): self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = 
self.num_heads // self.num_kv_heads - support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}. " - "Set VLLM_USE_V1=0 to use another attention backend.") + FlashAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " + if attn_type not in [ + AttentionType.DECODER, AttentionType.ENCODER_ONLY + ]: + raise NotImplementedError("Encoder/decoder cross-attention " + "is not implemented for " "FlashAttentionImpl") + self.use_irope = use_irope + self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ and not flash_attn_supports_fp8(): raise NotImplementedError( "FlashAttention does not support fp8 kv-cache on this device.") + self.sinks = sinks + if self.sinks is not None: + assert self.vllm_flash_attn_version == 3, ( + "Sinks are only supported in FlashAttention 3") + assert self.sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + "heads in the layer") + def forward( self, layer: torch.nn.Module, @@ -715,6 +734,7 @@ class FlashAttentionImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -732,6 +752,11 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashAttentionImpl") + if attn_metadata is None: # Profiling run. 
return output diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index f8455b5..7bb7b61 100644 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -37,6 +37,10 @@ class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True + @staticmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + @staticmethod def get_supported_head_sizes() -> list[int]: return [64, 128, 256] @@ -510,10 +514,10 @@ class FlashInferImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, use_irope: bool = False, ) -> None: if use_irope: @@ -543,6 +547,19 @@ class FlashInferImpl(AttentionImpl): "encoder/decoder cross-attention " "are not implemented for " "FlashInferImpl") + + self.sinks: Optional[torch.Tensor] = None + if sinks is not None: + if sinks.shape[0] != num_heads: + raise ValueError( + "Sinks must have the same number of heads as the number of " + f"heads in the layer. Expected {num_heads}, but got " + f"{sinks.shape[0]}." + ) + # Cast sinks to float32 if needed (FlashInfer requirement) + if sinks.dtype != torch.float32: + sinks = sinks.to(torch.float32) + self.sinks = sinks def forward( self, @@ -553,6 +570,7 @@ class FlashInferImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashInfer. @@ -567,6 +585,11 @@ class FlashInferImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." 
+ if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashInferImpl") + if attn_metadata is None: # Profiling run. return output diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index c3efb93..4579c5c 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -44,6 +44,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16, torch.float32] + @staticmethod def get_supported_head_sizes() -> list[int]: return [16, 32, 64, 96, 128, 160, 192, 224, 256] @@ -346,15 +350,10 @@ class FlexAttentionImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, ) -> None: - if blocksparse_params is not None: - # TODO we should support this :think - raise ValueError( - "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -410,6 +409,7 @@ class FlexAttentionImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: FlexAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FLexAttention. @@ -423,6 +423,11 @@ class FlexAttentionImpl(AttentionImpl): shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." 
+ if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlexAttentionImpl") + enable_gqa = self.num_kv_heads != self.num_heads if attn_metadata is None: diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 0f956ba..1373c1c 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -110,7 +110,6 @@ class PallasAttentionBackendImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, @@ -120,9 +119,6 @@ class PallasAttentionBackendImpl(AttentionImpl): logger.warning_once( "Using irope in Pallas is not supported yet, it will fall back " "to global attention for long context.") - if blocksparse_params is not None: - raise ValueError("Paged attention Pallas kernel does " - "not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -139,8 +135,6 @@ class PallasAttentionBackendImpl(AttentionImpl): raise NotImplementedError("Alibi slopes is not supported.") if kv_cache_dtype != "auto": raise NotImplementedError("FP8 KV cache dtype is not supported.") - if blocksparse_params is not None: - raise NotImplementedError("Blocksparse is not supported.") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " @@ -161,6 +155,7 @@ class PallasAttentionBackendImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: PallasMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. 
@@ -173,6 +168,11 @@ class PallasAttentionBackendImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for PallasAttentionBackendImpl") + # For determine_available_memory case. if kv_cache.numel() == 0: if output is None: diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 6a7c704..ebf06a5 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -3,6 +3,10 @@ """Attention layer with PagedAttention and Triton prefix prefill.""" from typing import TYPE_CHECKING, Any, Optional +from dataclasses import dataclass +from functools import cache +from typing import ClassVar, Optional + import torch from vllm import _custom_ops as ops @@ -12,7 +16,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.attention.ops.triton_unified_attention import unified_attention +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import ( @@ -38,10 +42,25 @@ class TritonAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - @staticmethod - def get_supported_head_sizes() -> list[int]: + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head 
size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. " + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + @staticmethod def get_name() -> str: return "TRITON_ATTN_VLLM_V1" @@ -74,6 +93,15 @@ class TritonAttentionBackend(AttentionBackend): return TritonAttentionMetadataBuilder +@cache +def use_aiter_unified_attention() -> bool: + """Check if aiter unified attention should be used.""" + # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set + # to 1 as default + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_USE_AITER_UNIFIED_ATTENTION + + class TritonAttentionImpl(AttentionImpl): def __init__( @@ -85,16 +113,11 @@ class TritonAttentionImpl(AttentionImpl): alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, - use_irope: bool = False, sinks: Optional[torch.Tensor] = None, ) -> None: - if blocksparse_params is not None: - raise ValueError( - "TritonAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -113,16 +136,9 @@ class TritonAttentionImpl(AttentionImpl): self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - self.use_irope = use_irope - - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - support_head_sizes = TritonAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by TritonAttention. 
" - f"Supported head sizes are: {support_head_sizes}.") + TritonAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " @@ -133,7 +149,23 @@ class TritonAttentionImpl(AttentionImpl): self.fp8_dtype = current_platform.fp8_dtype() self.force_prefill_decode_attn = \ envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - + + if not self.force_prefill_decode_attn: + # If not using prefill decode attention, we use the Triton + # unified attention implementation. + if use_aiter_unified_attention(): + logger.info_once( + "Using aiter unified attention for TritonAttentionImpl") + from aiter.ops.triton.unified_attention import ( + unified_attention) + self.unified_attention = unified_attention + else: + logger.info_once( + "Using vllm unified attention for TritonAttentionImpl") + from vllm.attention.ops.triton_unified_attention import ( + unified_attention) + self.unified_attention = unified_attention + self.sinks = sinks if sinks is not None: assert sinks.shape[0] == num_heads, ( @@ -150,6 +182,7 @@ class TritonAttentionImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -164,6 +197,11 @@ class TritonAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for TritonAttentionImpl") + if attn_metadata is None: # Profiling run. 
return output @@ -227,51 +265,41 @@ class TritonAttentionImpl(AttentionImpl): query.reshape( (num_tokens, num_heads * head_size)).contiguous(), layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) + query = query.reshape((num_tokens, num_heads, head_size)) - use_local_attn = \ - (self.use_irope and attn_metadata.local_attn_metadata is not None) - - if use_local_attn: - assert attn_metadata.local_attn_metadata is not None - local_metadata = attn_metadata.local_attn_metadata - cu_seqlens_q = local_metadata.local_query_start_loc - seqused_k = local_metadata.local_seqused_k - max_seqlen_q = local_metadata.local_max_query_len - max_seqlen_k = local_metadata.local_max_seq_len - block_table = local_metadata.local_block_table - else: - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table if use_prefill_decode_attn: # Compute attention and update output up to `num_actual_tokens`. 
- chunked_prefill_paged_decode(query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=block_table, - query_start_loc=cu_seqlens_q, - seq_lens=seqused_k, - max_seq_len=max_seqlen_k, - max_query_len=max_seqlen_q, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale, - sinks=self.sinks) + chunked_prefill_paged_decode( + query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=block_table, + query_start_loc=cu_seqlens_q, + seq_lens=seqused_k, + max_seq_len=max_seqlen_k, + max_query_len=max_seqlen_q, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + sm_scale=self.scale, + sinks=self.sinks, + ) else: descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - unified_attention( + self.unified_attention( q=query[:num_actual_tokens], k=key_cache, v=value_cache,