[gpt-oss] Add gpt-oss mxfp4 support
This commit is contained in:
212
.gitignore
vendored
212
.gitignore
vendored
@@ -1,2 +1,212 @@
|
||||
__pycache__/
|
||||
# version file generated by setuptools-scm
|
||||
/vllm/_version.py
|
||||
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/*
|
||||
|
||||
# triton jit
|
||||
.triton
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
cmake-build-*/
|
||||
CMakeUserPresets.json
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
/.deps/
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# generated files
|
||||
**/generated/**
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
docs/argparse
|
||||
docs/examples/*
|
||||
!docs/examples/README.md
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# VSCode
|
||||
.vscode/
|
||||
|
||||
# DS Store
|
||||
.DS_Store
|
||||
|
||||
# Results
|
||||
*.csv
|
||||
|
||||
# Python pickle files
|
||||
*.pkl
|
||||
|
||||
# Sphinx documentation
|
||||
_build/
|
||||
|
||||
# vim swap files
|
||||
*.swo
|
||||
*.swp
|
||||
|
||||
# hip files generated by PyTorch
|
||||
*.hip
|
||||
*_hip*
|
||||
hip_compat.h
|
||||
|
||||
# Benchmark dataset
|
||||
benchmarks/**/*.json
|
||||
|
||||
# Linting
|
||||
actionlint
|
||||
shellcheck*/
|
||||
|
||||
# Ignore moe/marlin_moe gen code
|
||||
csrc/moe/marlin_moe_wna16/kernel_*
|
||||
|
||||
# Ignore ep_kernels_workspace folder
|
||||
ep_kernels_workspace/
|
||||
@@ -1,4 +1,6 @@
|
||||
# metax-c500-vllm
|
||||
|
||||
1. 支持 `gpt-oss-BF16`:将 `vllm` 目录覆盖到镜像中的 `/opt/conda/lib/python3.10/site-packages/vllm`
|
||||
1. 支持 `gpt-oss`:将 `vllm` 目录覆盖到镜像中的 `/opt/conda/lib/python3.10/site-packages/vllm`。运行`gpt-oss`时需指定指定`VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1`
|
||||
2. 将 `code_generator.py` 覆盖到镜像中的 `/opt/conda/lib/python3.10/site-packages/triton/compiler/code_generator.py`
|
||||
|
||||
*此版本改动较大,可能因为接口改动,存在部分模型运行出错的问题。*
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -9,19 +9,49 @@ import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
|
||||
|
||||
logger = init_logger(__name__)
|
||||
USE_XFORMERS_OPS = None
|
||||
|
||||
|
||||
def check_xformers_availability():
|
||||
global USE_XFORMERS_OPS
|
||||
if USE_XFORMERS_OPS is not None:
|
||||
return USE_XFORMERS_OPS
|
||||
|
||||
if current_platform.is_cuda() and current_platform.has_device_capability(
|
||||
100):
|
||||
# Xformers FA is not compatible with B200
|
||||
USE_XFORMERS_OPS = False
|
||||
else:
|
||||
try:
|
||||
from importlib.util import find_spec
|
||||
|
||||
find_spec("xformers.ops")
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
|
||||
# the warning only needs to be shown once
|
||||
if not USE_XFORMERS_OPS:
|
||||
logger.warning("Xformers is not available, falling back.")
|
||||
|
||||
return USE_XFORMERS_OPS
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@@ -45,13 +75,13 @@ class Attention(nn.Module):
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
attn_backend: Optional[type[AttentionBackend]] = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -80,6 +110,9 @@ class Attention(nn.Module):
|
||||
calculate_kv_scales = False
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
assert num_heads % num_kv_heads == 0, \
|
||||
f"num_heads ({num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({num_kv_heads})"
|
||||
|
||||
# The default k/v_scale is set to 1.0. This is ignored
|
||||
# when kv-cache is not fp8, and should be used with
|
||||
@@ -105,6 +138,7 @@ class Attention(nn.Module):
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.has_sink = extra_impl_args.get("sinks") is not None
|
||||
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self, prefix=prefix) if quant_config else None
|
||||
@@ -126,19 +160,23 @@ class Attention(nn.Module):
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
blocksparse_params is not None,
|
||||
use_mla=use_mla)
|
||||
impl_cls = attn_backend.get_impl_cls()
|
||||
if attn_backend is None:
|
||||
self.attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
use_mla=use_mla,
|
||||
has_sink=self.has_sink)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
impl_cls = self.attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **extra_impl_args)
|
||||
self.backend = backend_name_to_enum(attn_backend.get_name())
|
||||
self.backend = backend_name_to_enum(self.attn_backend.get_name())
|
||||
self.dtype = dtype
|
||||
|
||||
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||
@@ -148,7 +186,7 @@ class Attention(nn.Module):
|
||||
self.use_direct_call = not current_platform.is_cuda_alike(
|
||||
) and not current_platform.is_cpu()
|
||||
|
||||
self.use_output = attn_backend.accept_output_buffer
|
||||
self.use_output = self.attn_backend.accept_output_buffer
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
@@ -206,7 +244,7 @@ class Attention(nn.Module):
|
||||
if self.use_output:
|
||||
output_shape = (output_shape
|
||||
if output_shape is not None else query.shape)
|
||||
output = torch.empty(output_shape,
|
||||
output = torch.zeros(output_shape,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
@@ -274,6 +312,9 @@ class Attention(nn.Module):
|
||||
if hasattr(self.impl, "process_weights_after_loading"):
|
||||
self.impl.process_weights_after_loading(act_dtype)
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return self.attn_backend
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""Multi-headed attention without any cache, used for ViT."""
|
||||
@@ -291,7 +332,9 @@ class MultiHeadAttention(nn.Module):
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
assert self.num_heads % self.num_kv_heads == 0, \
|
||||
f"num_heads ({self.num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
@@ -301,12 +344,21 @@ class MultiHeadAttention(nn.Module):
|
||||
block_size=16,
|
||||
is_attention_free=False)
|
||||
backend = backend_name_to_enum(attn_backend.get_name())
|
||||
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
|
||||
backend = _Backend.XFORMERS
|
||||
if current_platform.is_rocm():
|
||||
# currently, only torch_sdpa is supported on rocm
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
else:
|
||||
if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
|
||||
_Backend.FLEX_ATTENTION):
|
||||
backend = _Backend.XFORMERS
|
||||
|
||||
self.attn_backend = backend if backend in {
|
||||
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
|
||||
} else _Backend.TORCH_SDPA
|
||||
self.attn_backend = backend if backend in {
|
||||
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
|
||||
} else _Backend.TORCH_SDPA
|
||||
|
||||
if (self.attn_backend == _Backend.XFORMERS
|
||||
and not check_xformers_availability()):
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -430,6 +482,7 @@ def unified_attention_with_output(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
@@ -444,7 +497,8 @@ def unified_attention_with_output(
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
output=output,
|
||||
output_scale=output_scale)
|
||||
|
||||
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
|
||||
|
||||
@@ -455,6 +509,7 @@ def unified_attention_with_output_fake(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import Generator, Optional, Type
|
||||
from typing import Generator, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -79,15 +80,72 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
|
||||
return forced_attn_backend
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _IsSupported:
|
||||
can_import: bool
|
||||
head_size: bool
|
||||
dtype: bool
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.can_import and self.head_size and self.dtype
|
||||
|
||||
|
||||
def is_attn_backend_supported(
|
||||
attn_backend: Union[str, type[AttentionBackend]],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
allow_import_error: bool = True,
|
||||
) -> _IsSupported:
|
||||
if isinstance(attn_backend, str):
|
||||
try:
|
||||
attn_backend = resolve_obj_by_qualname(attn_backend)
|
||||
except ImportError:
|
||||
if not allow_import_error:
|
||||
raise
|
||||
|
||||
return _IsSupported(can_import=False, head_size=False, dtype=False)
|
||||
|
||||
assert isinstance(attn_backend, type)
|
||||
|
||||
# TODO: Update the interface once V0 is removed
|
||||
if get_supported_head_sizes := getattr(attn_backend,
|
||||
"get_supported_head_sizes", None):
|
||||
is_head_size_supported = head_size in get_supported_head_sizes()
|
||||
elif validate_head_size := getattr(attn_backend, "validate_head_size",
|
||||
None):
|
||||
try:
|
||||
validate_head_size(head_size)
|
||||
is_head_size_supported = True
|
||||
except Exception:
|
||||
is_head_size_supported = False
|
||||
else:
|
||||
raise NotImplementedError(f"{attn_backend.__name__} does not support "
|
||||
"head size validation")
|
||||
|
||||
if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes",
|
||||
None):
|
||||
is_dtype_supported = dtype in get_supported_dtypes()
|
||||
else:
|
||||
raise NotImplementedError(f"{attn_backend.__name__} does not support "
|
||||
"dtype validation")
|
||||
|
||||
return _IsSupported(
|
||||
can_import=True,
|
||||
head_size=is_head_size_supported,
|
||||
dtype=is_dtype_supported,
|
||||
)
|
||||
|
||||
|
||||
def get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
is_blocksparse: bool = False,
|
||||
is_attention_free: bool = False,
|
||||
use_mla: bool = False,
|
||||
) -> Type[AttentionBackend]:
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
# value to be returned from the cache if the value changes between calls.
|
||||
@@ -99,9 +157,9 @@ def get_attn_backend(
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
is_attention_free=is_attention_free,
|
||||
is_blocksparse=is_blocksparse,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
has_sink=has_sink,
|
||||
)
|
||||
|
||||
|
||||
@@ -112,16 +170,10 @@ def _cached_get_attn_backend(
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
is_blocksparse: bool = False,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
) -> Type[AttentionBackend]:
|
||||
if is_blocksparse:
|
||||
logger.info("Using BlocksparseFlashAttention backend.")
|
||||
from vllm.attention.backends.blocksparse_attn import (
|
||||
BlocksparseFlashAttentionBackend)
|
||||
return BlocksparseFlashAttentionBackend
|
||||
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
# If there are no attention layers (e.g. we are running Mamba),
|
||||
# use the placeholder NO_ATTENTION
|
||||
if is_attention_free:
|
||||
@@ -144,11 +196,15 @@ def _cached_get_attn_backend(
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. "
|
||||
f"Valid backends are: {list(_Backend.__members__.keys())}")
|
||||
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
|
||||
use_mla)
|
||||
use_mla, has_sink)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}")
|
||||
|
||||
33
vllm/attention/utils/kv_sharing_utils.py
Normal file
33
vllm/attention/utils/kv_sharing_utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
def validate_kv_sharing_target(current_layer_name, target_layer_name,
|
||||
static_forward_context):
|
||||
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
|
||||
f"is not valid: target layer {target_layer_name} ")
|
||||
|
||||
if current_layer_name == target_layer_name:
|
||||
raise ValueError(error_msg +
|
||||
"cannot be the same as the current layer.")
|
||||
|
||||
if target_layer_name not in static_forward_context:
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
|
||||
# If target layer name is not in the static fwd context, it means either
|
||||
# a) the target layer does not come BEFORE the current layer, or
|
||||
# b) the target layer is not an Attention layer that exists in the model
|
||||
current_layer_idx = extract_layer_index(current_layer_name)
|
||||
target_layer_idx = extract_layer_index(target_layer_name)
|
||||
if current_layer_idx <= target_layer_idx:
|
||||
raise ValueError(error_msg + "must come before the current layer.")
|
||||
else:
|
||||
raise ValueError(error_msg +
|
||||
"is not a valid Attention layer in the model.")
|
||||
|
||||
# Currently KV sharing is only supported between layers of the same type
|
||||
target_layer_attn_type = static_forward_context[
|
||||
target_layer_name].attn_type
|
||||
expected = static_forward_context[current_layer_name].attn_type
|
||||
if target_layer_attn_type != expected:
|
||||
raise ValueError(
|
||||
error_msg +
|
||||
f"must be the same type as the current layer ({expected}).")
|
||||
32
vllm/envs.py
32
vllm/envs.py
@@ -112,8 +112,10 @@ if TYPE_CHECKING:
|
||||
VLLM_DP_SIZE: int = 1
|
||||
VLLM_DP_MASTER_IP: str = ""
|
||||
VLLM_DP_MASTER_PORT: int = 0
|
||||
VLLM_MOE_DP_CHUNK_SIZE: int = 256
|
||||
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
|
||||
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
|
||||
VLLM_MXFP4_USE_MARLIN: Optional[bool] = None
|
||||
VLLM_V0_USE_OUTLINES_CACHE: bool = False
|
||||
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
|
||||
VLLM_USE_DEEP_GEMM: bool = False
|
||||
@@ -128,6 +130,8 @@ if TYPE_CHECKING:
|
||||
VLLM_SLEEP_WHEN_IDLE: bool = False
|
||||
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
|
||||
MACA_VLLM_USE_TN_2_NN: bool = True
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
|
||||
|
||||
def get_default_cache_root():
|
||||
return os.getenv(
|
||||
@@ -149,6 +153,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
|
||||
return int(value)
|
||||
|
||||
|
||||
def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
|
||||
if value is None:
|
||||
return None
|
||||
return bool(int(value))
|
||||
|
||||
|
||||
def get_vllm_port() -> Optional[int]:
|
||||
"""Get the port from VLLM_PORT environment variable.
|
||||
|
||||
@@ -769,6 +779,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_DP_MASTER_IP":
|
||||
lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"),
|
||||
|
||||
# In the context of executing MoE models with Data-Parallel, Expert-Parallel
|
||||
# and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE
|
||||
# dictates the quantum of tokens that can be dispatched from a DP
|
||||
# rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE
|
||||
# units.
|
||||
"VLLM_MOE_DP_CHUNK_SIZE":
|
||||
lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")),
|
||||
|
||||
# Port of the master node in the data parallel setting
|
||||
"VLLM_DP_MASTER_PORT":
|
||||
lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),
|
||||
@@ -794,6 +812,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_MARLIN_USE_ATOMIC_ADD":
|
||||
lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1",
|
||||
|
||||
# Whether to use marlin kernel in mxfp4 quantization method
|
||||
"VLLM_MXFP4_USE_MARLIN":
|
||||
lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)),
|
||||
|
||||
# Whether to turn on the outlines cache for V0
|
||||
# This cache is unbounded and on disk, so it's not safe to use in
|
||||
# an environment with potentially malicious users.
|
||||
@@ -810,6 +832,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_DEEP_GEMM":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
|
||||
|
||||
# If set to 1, use the FlashInfer
|
||||
# MXFP8 (activation) x MXFP4 (weight) MoE backend.
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
|
||||
|
||||
# If set to 1, use the FlashInfer
|
||||
# BF16 (activation) x MXFP4 (weight) MoE backend.
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))),
|
||||
|
||||
# Control the cache sized used by the xgrammar compiler. The default
|
||||
# of 512 MB should be enough for roughly 1000 JSON schemas.
|
||||
# It can be changed with this variable if needed for some reason.
|
||||
|
||||
@@ -4,8 +4,12 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize)
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
_config: Optional[dict[str, Any]] = None
|
||||
@@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]:
|
||||
|
||||
__all__ = [
|
||||
"FusedMoE",
|
||||
"FusedMoEConfig",
|
||||
"FusedMoEMethodBase",
|
||||
"FusedMoeWeightScaleSupported",
|
||||
"FusedMoEPermuteExpertsUnpermute",
|
||||
"FusedMoEActivationFormat",
|
||||
"FusedMoEPrepareAndFinalize",
|
||||
"override_config",
|
||||
"get_config",
|
||||
]
|
||||
@@ -36,11 +44,21 @@ if HAS_TRITON:
|
||||
# import to register the custom ops
|
||||
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
||||
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp4, cutlass_moe_fp8)
|
||||
CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
TritonExperts, fused_experts, fused_moe, fused_topk,
|
||||
get_config_file_name, grouped_topk)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
|
||||
__all__ += [
|
||||
"fused_moe",
|
||||
@@ -50,5 +68,11 @@ if HAS_TRITON:
|
||||
"grouped_topk",
|
||||
"cutlass_moe_fp8",
|
||||
"cutlass_moe_fp4",
|
||||
"CutlassExpertsFp8",
|
||||
"TritonExperts",
|
||||
"BatchedTritonExperts",
|
||||
"DeepGemmExperts",
|
||||
"BatchedDeepGemmExperts",
|
||||
"TritonOrDeepGemmExperts",
|
||||
"BatchedTritonOrDeepGemmExperts",
|
||||
]
|
||||
|
||||
490
vllm/model_executor/layers/fused_moe/config.py
Normal file
490
vllm/model_executor/layers/fused_moe/config.py
Normal file
@@ -0,0 +1,490 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.utils import cdiv
|
||||
# from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_quant_config_quantization_args(
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prop_name: str,
|
||||
) -> Optional[QuantizationArgs]:
|
||||
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
|
||||
and "Linear" in quant_config.target_scheme_map and
|
||||
"input_activations" in quant_config.target_scheme_map["Linear"]):
|
||||
return quant_config.target_scheme_map["Linear"].get(prop_name)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_quant_config_input_quant(
|
||||
quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[QuantizationArgs]:
|
||||
return _get_quant_config_quantization_args(quant_config,
|
||||
"input_activations")
|
||||
|
||||
|
||||
def get_quant_config_weight_quant(
|
||||
quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[QuantizationArgs]:
|
||||
return _get_quant_config_quantization_args(quant_config, "weights")
|
||||
|
||||
|
||||
# TODO (bnell): use scalar_type instead of bools?
|
||||
def get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
use_mxfp4_w4a4: bool,
|
||||
) -> Union[None, torch.dtype, str]:
|
||||
if use_fp8_w8a8:
|
||||
return torch.float8_e4m3fn
|
||||
elif use_int8_w8a8:
|
||||
return torch.int8
|
||||
elif use_mxfp4_w4a4:
|
||||
return "mxfp4"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEQuantConfig:
|
||||
# The post quantization activation type.
|
||||
quant_dtype: Optional[torch.dtype] = None
|
||||
per_act_token_quant: bool = False
|
||||
per_out_ch_quant: bool = False
|
||||
block_shape: Optional[list[int]] = None
|
||||
|
||||
# TODO: add col major flag?
|
||||
# add detailed quant info for input, intermediates, weights, etc?
|
||||
|
||||
def __post_init__(self):
|
||||
assert (not self.per_act_token_quant
|
||||
or self.block_shape is None), "illegal quantization"
|
||||
|
||||
@property
|
||||
def is_quantized(self) -> bool:
|
||||
return self.quant_dtype is not None
|
||||
|
||||
@property
|
||||
def is_per_act_token(self) -> bool:
|
||||
return self.per_act_token_quant
|
||||
|
||||
@property
|
||||
def is_block_quantized(self) -> bool:
|
||||
return self.block_shape is not None
|
||||
|
||||
@property
|
||||
def is_per_tensor(self) -> bool:
|
||||
return not self.per_act_token_quant and self.block_shape is None
|
||||
|
||||
def scale_shape(
|
||||
self,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int]]:
|
||||
if self.is_quantized:
|
||||
if self.is_block_quantized:
|
||||
assert self.block_shape is not None
|
||||
_, block_k = self.block_shape
|
||||
k_tiles = cdiv(hidden_dim, block_k)
|
||||
return (max_tokens, k_tiles)
|
||||
elif self.is_per_act_token:
|
||||
return (max_tokens, 1)
|
||||
else:
|
||||
return (1, 1)
|
||||
else:
|
||||
return None
|
||||
|
||||
def batched_scale_shape(
|
||||
self,
|
||||
num_experts: int,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int, int]]:
|
||||
if self.is_quantized:
|
||||
scale_shape = self.scale_shape(max_tokens, hidden_dim)
|
||||
assert scale_shape is not None
|
||||
return (num_experts, *scale_shape)
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
assert sum([
|
||||
int(flag) for flag in [
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
]
|
||||
]) <= 1, "Quantization flags are mutually exclusive."
|
||||
|
||||
quant_dtype = get_config_quant_dtype(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
)
|
||||
return FusedMoEQuantConfig(
|
||||
quant_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEParallelConfig:
|
||||
tp_size: int
|
||||
dp_size: int
|
||||
ep_size: int
|
||||
tp_rank: int
|
||||
dp_rank: int
|
||||
ep_rank: int
|
||||
|
||||
use_ep: bool # whether to use EP or not
|
||||
|
||||
@property
|
||||
def use_all2all_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
|
||||
|
||||
@property
|
||||
def use_flashinfer_cutlass_kernels(self):
|
||||
# return (envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
# and has_flashinfer_cutlass_fused_moe()
|
||||
# and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def make(tp_size_: int, dp_size_: int,
|
||||
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
|
||||
"""
|
||||
Determine MoE parallel configuration. Based on the input `tp_size_`,
|
||||
`dp_size_` and vllm's parallel config, determine what
|
||||
level's of parallelism to use in the fused moe layer.
|
||||
|
||||
Args:
|
||||
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
|
||||
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
|
||||
vllm_parallel_config (ParallelConfig): vLLM's parallel config
|
||||
object which contains the `enable_expert_parallel` flag.
|
||||
|
||||
Examples:
|
||||
When there is no parallelism requested,
|
||||
i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
|
||||
unaltered and the ranks set to 0.
|
||||
|
||||
Expert Parallelism is considered only when either `dp_size_` or
|
||||
`tp_size_` is non trivial.
|
||||
|
||||
When TP = 2, DP = 1 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
||||
legend : {size, rank}
|
||||
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
||||
- Comment : Tensors are sharded across 2 devices.
|
||||
|
||||
When TP = 1, DP = 2 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 2 decvices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
|
||||
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
|
||||
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 4 devices.
|
||||
|
||||
When, TP = 2, DP = 1 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
||||
- Comment: The experts are split between the 2 devices.
|
||||
|
||||
When, TP = 1, DP = 2 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 2 devices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
|
||||
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
|
||||
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 4 devices.
|
||||
"""
|
||||
|
||||
def flatten_tp_across_dp(dp_rank: int):
|
||||
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
|
||||
# There are actually dp_size_ * tp_size_ devices. Update tp_size
|
||||
# and tp_rank so we shard across all devices.
|
||||
tp_size = dp_size_ * tp_size_
|
||||
tp_rank = dp_rank * tp_size_ + tp_rank
|
||||
return tp_size, tp_rank
|
||||
|
||||
use_ep = (dp_size_ * tp_size_ > 1
|
||||
and vllm_parallel_config.enable_expert_parallel)
|
||||
|
||||
dp_size = dp_size_
|
||||
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
||||
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
|
||||
|
||||
if not use_ep:
|
||||
return FusedMoEParallelConfig(tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=1,
|
||||
ep_rank=0,
|
||||
use_ep=False)
|
||||
# DP + EP / TP + EP / DP + TP + EP
|
||||
assert use_ep
|
||||
# In EP, each device owns a set of experts fully. There is no tensor
|
||||
# parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
|
||||
ep_size = tp_size
|
||||
ep_rank = tp_rank
|
||||
return FusedMoEParallelConfig(tp_size=1,
|
||||
tp_rank=0,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
use_ep=True)
|
||||
|
||||
|
||||
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
||||
@dataclass
|
||||
class FusedMoEConfig:
|
||||
num_experts: int
|
||||
experts_per_token: int
|
||||
hidden_dim: int
|
||||
|
||||
num_local_experts: int
|
||||
moe_parallel_config: FusedMoEParallelConfig
|
||||
|
||||
# The activation type.
|
||||
in_dtype: torch.dtype
|
||||
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||
|
||||
has_bias: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dp_size > 1:
|
||||
logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d",
|
||||
self.max_num_tokens)
|
||||
|
||||
assert self.max_num_tokens > 0
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.quant_dtype
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.block_shape
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def per_act_token_quant(self) -> bool:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.per_act_token_quant
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def per_out_ch_quant(self) -> bool:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.per_out_ch_quant
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
|
||||
@property
|
||||
def dp_size(self):
|
||||
return self.moe_parallel_config.dp_size
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return self.moe_parallel_config.ep_size
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return self.moe_parallel_config.tp_rank
|
||||
|
||||
@property
|
||||
def dp_rank(self):
|
||||
return self.moe_parallel_config.dp_rank
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return self.moe_parallel_config.ep_rank
|
||||
|
||||
@property
|
||||
def use_ep(self):
|
||||
return self.moe_parallel_config.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.moe_parallel_config.use_pplx_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ht_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ll_kernels
|
||||
|
||||
@property
|
||||
def use_flashinfer_cutlass_kernels(self):
|
||||
return self.moe_parallel_config.use_flashinfer_cutlass_kernels
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
num_experts: int,
|
||||
experts_per_token: int,
|
||||
hidden_dim: int,
|
||||
num_local_experts: int,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
in_dtype: torch.dtype,
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
quant_config: Optional[Union[FusedMoEQuantConfig,
|
||||
QuantizationConfig]] = None,
|
||||
has_bias: bool = False,
|
||||
) -> "FusedMoEConfig":
|
||||
|
||||
_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
|
||||
if quant_config is not None and isinstance(quant_config,
|
||||
QuantizationConfig):
|
||||
if hasattr(quant_config, 'weight_block_size'):
|
||||
block_shape = quant_config.weight_block_size
|
||||
else:
|
||||
block_shape = None
|
||||
per_act_token_quant = False
|
||||
per_out_ch_quant = False
|
||||
quant_dtype: Optional[torch.dtype] = None
|
||||
|
||||
input_quant = get_quant_config_input_quant(quant_config)
|
||||
weight_quant = get_quant_config_weight_quant(quant_config)
|
||||
|
||||
if input_quant is not None:
|
||||
per_act_token_quant = (input_quant.strategy
|
||||
== QuantizationStrategy.TOKEN
|
||||
if input_quant is not None else False)
|
||||
|
||||
if input_quant.num_bits == 8:
|
||||
if input_quant.type == QuantizationType.FLOAT:
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
elif input_quant.type == QuantizationType.INT:
|
||||
quant_dtype = torch.int8
|
||||
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
if quant_dtype is None and isinstance(quant_config, Fp8Config):
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4Config)
|
||||
if quant_dtype is None and isinstance(quant_config,
|
||||
ModelOptNvFp4Config):
|
||||
quant_dtype = torch.uint8
|
||||
|
||||
if weight_quant is not None:
|
||||
per_out_ch_quant = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||
|
||||
if quant_dtype is not None:
|
||||
_quant_config = FusedMoEQuantConfig(
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
else:
|
||||
_quant_config = FusedMoEQuantConfig()
|
||||
if moe_parallel_config.dp_size > 1:
|
||||
logger.warning_once("MoE DP setup unable to determine "
|
||||
"quantization scheme or unsupported "
|
||||
"quantization type. This model will "
|
||||
"not run with DP enabled.")
|
||||
else:
|
||||
_quant_config = quant_config
|
||||
|
||||
return FusedMoEConfig(
|
||||
num_experts=num_experts,
|
||||
experts_per_token=experts_per_token,
|
||||
hidden_dim=hidden_dim,
|
||||
num_local_experts=num_local_experts,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=in_dtype,
|
||||
quant_config=_quant_config,
|
||||
max_num_tokens=max_num_tokens,
|
||||
has_bias=has_bias,
|
||||
)
|
||||
@@ -1503,8 +1503,8 @@ def fused_experts_impl(
|
||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||
A=curr_hidden_states,
|
||||
A_scale=a1_scale,
|
||||
qtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
quant_dtype=qtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(qcurr_hidden_states,
|
||||
@@ -1562,8 +1562,8 @@ def fused_experts_impl(
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
A_scale=a2_scale,
|
||||
qtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
quant_dtype=qtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
|
||||
from vllm.utils import has_triton_kernels
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if has_triton_kernels():
|
||||
try:
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
|
||||
matmul_ogs)
|
||||
from triton_kernels.routing import routing
|
||||
except ModuleNotFoundError:
|
||||
logger.error(
|
||||
"Failed to import Triton kernels. Please make sure your triton "
|
||||
"version is compatible.")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from triton_kernels.matmul_ogs import PrecisionConfig
|
||||
|
||||
|
||||
def triton_kernel_moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
w1, # Tensor or triton_kernels.Tensor
|
||||
w2, # Tensor or triton_kernels.Tensor
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
w1_precision: Optional["PrecisionConfig"] = None,
|
||||
w2_precision: Optional["PrecisionConfig"] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
routing_data, gather_idx, scatter_idx = routing(gating_output,
|
||||
topk,
|
||||
sm_first=not renormalize)
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
None,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
routing_data,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
per_channel_quant=per_channel_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_precision=w1_precision,
|
||||
w2_precision=w2_precision,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
# This is a triton implementation of the fused_experts function
|
||||
def triton_kernel_fused_experts(
|
||||
output_tensor: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1, # Tensor or triton_kernels.Tensor
|
||||
w2, # Tensor or triton_kernels.Tensor
|
||||
routing_data, # RoutingData
|
||||
gather_indx, # GatherIndx
|
||||
scatter_indx, # ScatterIndx
|
||||
activation: str = "silu",
|
||||
swiglu_alpha: float = 1.702,
|
||||
swiglu_limit: float = 7.0,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
w1_precision: Optional["PrecisionConfig"] = None,
|
||||
w2_precision: Optional["PrecisionConfig"] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
assert w1_bias is None or w1_bias.dtype == torch.float32
|
||||
assert w2_bias is None or w2_bias.dtype == torch.float32
|
||||
|
||||
# Shape check, only check non-mxfp4
|
||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||
assert w2.shape[-1] == w1.shape[1]
|
||||
|
||||
E, _, N = w1.shape
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
act = FusedActivation(
|
||||
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
|
||||
(swiglu_alpha, swiglu_limit), 2)
|
||||
gammas = routing_data.gate_scal if routing_data else None
|
||||
|
||||
intermediate_cache1 = matmul_ogs(
|
||||
hidden_states,
|
||||
w1,
|
||||
w1_bias,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
precision_config=w1_precision,
|
||||
gammas=gammas if apply_router_weight_on_input else None,
|
||||
fused_activation=act)
|
||||
|
||||
intermediate_cache3 = matmul_ogs(
|
||||
intermediate_cache1,
|
||||
w2,
|
||||
w2_bias,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
precision_config=w2_precision,
|
||||
gammas=None if apply_router_weight_on_input else gammas,
|
||||
y=output_tensor,
|
||||
)
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
w1_precision: "PrecisionConfig",
|
||||
w2_precision: "PrecisionConfig",
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
self.w1_precision = w1_precision
|
||||
self.w2_precision = w2_precision
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
|
||||
def workspace_shapes(
|
||||
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
|
||||
topk: int, global_num_experts: int, local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
# workspace are allocated inside the kernel
|
||||
assert a.dim() == 2
|
||||
num_dp = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
max_num_tokens = self.max_num_tokens
|
||||
workspace2 = (0, 0, 0)
|
||||
output = (num_experts, max_num_tokens * num_dp, N)
|
||||
return (output, workspace2, output, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]],
|
||||
):
|
||||
w1_bias, w2_bias = (extract_required_args(extra_expert_args,
|
||||
["w1_bias", "w2_bias"]))
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
output,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=False,
|
||||
use_fp8_w8a8=False,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_precision=self.w1_precision,
|
||||
w2_precision=self.w2_precision,
|
||||
a1_scale=a1q_scale,
|
||||
a2_scale=a2_scale)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from math import prod
|
||||
from typing import Any, Optional, final
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
|
||||
_resize_cache, count_expert_num_tokens)
|
||||
from vllm.utils import cdiv
|
||||
|
||||
#
|
||||
# This file defines a set of base classes used to make MoE kernels more modular.
|
||||
# The goal is to be able to utilize different communication mechanisms with
|
||||
@@ -14,7 +23,7 @@ import torch
|
||||
#
|
||||
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
|
||||
#
|
||||
# Each component will be independent of the others except for
|
||||
# Each component will be independent of (but may inform) the others except for
|
||||
# [Quantize-Dispatch] and `[Combine] (see below). The components can then be
|
||||
# mixed and matched with so that DP+EP can be supported easily for multiple
|
||||
# MoE kernel implementations.
|
||||
@@ -23,13 +32,19 @@ import torch
|
||||
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
|
||||
# inputs (e.g. quantization, distribution) and finalization of Moe outputs.
|
||||
# The prepare method must take care of any needed quantization and the
|
||||
# finalize method must apply weights and do the final reduction of the output.
|
||||
# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
|
||||
# may apply weights and/or do the final reduction of the output.
|
||||
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
|
||||
# MoE operation. One important feature to note is that this class does not
|
||||
# apply topk weights or reduce the final output.
|
||||
# MoE operation, i.e matmul + act_mul + optionally quant + matmul.
|
||||
# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
|
||||
# the weight application and/or reduction. The class communicates this
|
||||
# to [Finalize] via a TopKWeightAndReduce object.
|
||||
# * FusedMoEModularKernel - an interface class that combines a
|
||||
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
|
||||
# provide the standard fused MoE kernel interface.
|
||||
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
|
||||
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
|
||||
# on to [Finalize].
|
||||
#
|
||||
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
|
||||
# class `FusedMoEPrepareAndFinalize` since they could use collective
|
||||
@@ -77,6 +92,56 @@ def _moe_problem_size(
|
||||
return E, M, N, K, topk
|
||||
|
||||
|
||||
class FusedMoEActivationFormat(Enum):
|
||||
"""
|
||||
The standard activation format (num_tokens, hidden dim).
|
||||
"""
|
||||
Standard = "standard",
|
||||
"""
|
||||
The batched experts format (num experts, max tokens per expert, hidden dim)
|
||||
"""
|
||||
BatchedExperts = "batched_experts",
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpertTokensMetadata:
|
||||
"""
|
||||
Metadata regarding expert-token routing.
|
||||
"""
|
||||
expert_num_tokens: torch.Tensor
|
||||
expert_num_tokens_cpu: Optional[torch.Tensor]
|
||||
|
||||
@staticmethod
|
||||
def make_from_list(expert_num_tokens_list: list[int],
|
||||
device: str) -> "ExpertTokensMetadata":
|
||||
expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=expert_num_tokens_cpu.to(device,
|
||||
non_blocking=True),
|
||||
expert_num_tokens_cpu=expert_num_tokens_cpu)
|
||||
|
||||
|
||||
class TopKWeightAndReduce(ABC):
|
||||
"""
|
||||
An abstract base class for weight application and reduction implementations.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, output: Optional[torch.Tensor],
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
"""
|
||||
Apply topk_weights to the fused_experts_outputs and/or reduce.
|
||||
If an output tensor is not passed, it will be created in the
|
||||
function.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
@@ -85,17 +150,15 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, num_experts: int,
|
||||
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
extra_prepare_args: Optional[dict[str, Any]]
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed
|
||||
for this kernel.
|
||||
@@ -114,22 +177,20 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
Returns a tuple of:
|
||||
- quantized + dispatched a.
|
||||
- quantized + dispatched a1_scales.
|
||||
- Optional tensor as big as number of local experts that contains the
|
||||
number of tokens assigned to each local expert.
|
||||
- Optional ExpertTokensMetadata containing gpu/cpu tensors
|
||||
as big as the number of local experts with the information about the
|
||||
number of tokens assigned to each local expert.
|
||||
- Optional dispatched expert topk IDs
|
||||
- Optional dispatched expert topk weight
|
||||
- Optional dispatched expert topk weight
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: TopKWeightAndReduce,
|
||||
extra_finalize_args: Optional[dict[str, Any]]) -> None:
|
||||
"""
|
||||
Perform any combine plus apply weights and perform a reduction on the
|
||||
fused experts output.
|
||||
@@ -140,6 +201,17 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- topk_ids: The topk_ids.
|
||||
- apply_router_weight_on_input: When False, apply the weights to
|
||||
fused_expert_output.
|
||||
- weight_and_reduce_impl: An optional TopKWeightAndReduce
|
||||
implementation.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def activation_format(self) -> FusedMoEActivationFormat:
|
||||
"""
|
||||
A property indicating the output format of the activations for the
|
||||
'prepare' method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -159,11 +231,15 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
Some PrepareFinalize All2All implementations are batched. Meaning,
|
||||
they can processes only as set of tokens at a time. This
|
||||
function returns the batch size i.e the maximum number of tokens
|
||||
the implementation can process at a time.
|
||||
the implementation can process at a time.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def num_dispatchers(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
@@ -171,6 +247,57 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
above.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
):
|
||||
if quant_config is not None:
|
||||
self.quant_config = quant_config
|
||||
else:
|
||||
self.quant_config = FusedMoEQuantConfig()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def activation_formats(
|
||||
self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
|
||||
"""
|
||||
A property which is a tuple of the input and output activation formats
|
||||
for the 'apply' method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||
return self.quant_config.quant_dtype
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
return self.quant_config.block_shape
|
||||
|
||||
@property
|
||||
def per_act_token_quant(self) -> bool:
|
||||
return self.quant_config.per_act_token_quant
|
||||
|
||||
@property
|
||||
def per_out_ch_quant(self) -> bool:
|
||||
return self.quant_config.per_out_ch_quant
|
||||
|
||||
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
||||
@abstractmethod
|
||||
def supports_chunking(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports activation
|
||||
chunking.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def workspace_shapes(
|
||||
self,
|
||||
@@ -180,20 +307,25 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
"""
|
||||
Compute the number of elements for the temporary outputs of the two
|
||||
gemms and activation in the fused expert function. Since the
|
||||
gemms are independent, the workspace for the first gemm can be shared
|
||||
with the workspace for the last gemm.
|
||||
Compute the shapes for the temporary and final outputs of the two gemms
|
||||
and activation in the fused expert function. Since the gemms are
|
||||
independent, the workspace for the first gemm can be shared with the
|
||||
workspace for the last gemm.
|
||||
|
||||
Returns a tuple of:
|
||||
- Number of workspace13 elements: must be large enough to hold the
|
||||
- workspace13 shape tuple: must be large enough to hold the
|
||||
result of either expert gemm.
|
||||
- Number of workspace2 elements: must be large enough to hold the
|
||||
- workspace2 shape tuple: must be large enough to hold the
|
||||
result of the activation function.
|
||||
- output shape tuple: must be exact size of the final gemm output.
|
||||
- Workspace type: The dtype to use for the workspace tensors.
|
||||
- Note: in order for activation chunking to work, the first dimension
|
||||
of each tuple must be the number of tokens.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -207,12 +339,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
def enable_chunking(self):
|
||||
return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
|
||||
self.supports_chunking()
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
@@ -225,17 +366,22 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]],
|
||||
):
|
||||
"""
|
||||
This function computes the intermediate result of a Mixture of Experts
|
||||
(MoE) layer using two sets of weights, w1 and w2.
|
||||
|
||||
Parameters:
|
||||
- output: (torch.Tensor): The unweighted, unreduced output tensor.
|
||||
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
|
||||
layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_weights: A map of row to expert weights. Some implementations
|
||||
choose to do weight application.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
@@ -257,15 +403,28 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
must be large enough to hold output of either MoE gemm.
|
||||
- workspace2 (torch.Tensor): A scratch tensor used for the activation
|
||||
function.
|
||||
- expert_num_tokens: An optional tensor containing the number of tokens
|
||||
assigned to each expert when using batched experts format input.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The unweighted, unreduced output tensor
|
||||
- expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
|
||||
ExpertTokensMetadata object containing gpu/cpu tensors
|
||||
as big as the number of local experts with the information about the
|
||||
number of tokens assigned to each local expert.
|
||||
- apply_router_weight_on_input: True if router weights are already
|
||||
applied on the input. This is relevant if the implementation
|
||||
chooses to do weight application.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||
end: int) -> Optional[torch.Tensor]:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
return scales
|
||||
else:
|
||||
return scales[start:end]
|
||||
return None
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
This class combines a FusedMoEPrepareAndFinalize instance and
|
||||
@@ -287,46 +446,56 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
assert prepare_finalize.activation_format == \
|
||||
fused_experts.activation_formats[0], (
|
||||
f"{prepare_finalize.__class__.__name__}."
|
||||
f"{prepare_finalize.activation_format} == "
|
||||
f"{fused_experts.__class__.__name__}."
|
||||
f"{fused_experts.activation_formats[0]}")
|
||||
|
||||
def _do_fused_experts(
|
||||
self,
|
||||
a1: torch.Tensor, # input to forward fn
|
||||
a1q: torch.Tensor, # output of prepare fn
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
|
||||
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
activation: str, global_num_experts: int, local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
|
||||
|
||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
# Use a1 here to decipher the correct workspace datatype
|
||||
workspace13_shape, workspace2_shape, workspace_dtype = (
|
||||
self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
|
||||
global_num_experts))
|
||||
(workspace13_shape, workspace2_shape, fused_out_shape,
|
||||
workspace_dtype) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the time
|
||||
# we need cache3, we're done with cache1
|
||||
workspace13 = torch.zeros(workspace13_shape,
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = torch.empty(prod(workspace13_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = torch.zeros(workspace2_shape,
|
||||
workspace2 = torch.empty(prod(workspace2_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
fused_out = self.fused_experts.apply(
|
||||
assert fused_out is None or fused_out.shape == fused_out_shape, (
|
||||
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
|
||||
if fused_out is None:
|
||||
# reuse workspace13 for the output
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
|
||||
self.fused_experts.apply(
|
||||
fused_out,
|
||||
a1q,
|
||||
w1,
|
||||
w2,
|
||||
topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
@@ -338,8 +507,162 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
)
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=extra_expert_args)
|
||||
|
||||
return fused_out
|
||||
|
||||
def _maybe_chunk_fused_experts(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1q: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]],
|
||||
) -> torch.Tensor:
|
||||
|
||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_chunks = cdiv(M, CHUNK_SIZE)
|
||||
|
||||
if not self.fused_experts.supports_chunking() or num_chunks == 1:
|
||||
return self._do_fused_experts(
|
||||
fused_out=None,
|
||||
a1=a1,
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=extra_expert_args)
|
||||
|
||||
# Chunking required case
|
||||
assert num_chunks > 1
|
||||
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
fused_out = torch.empty(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M)
|
||||
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
|
||||
_chunk_scales(a2_scale, s,
|
||||
e), topk_ids[s:e], topk_weights[s:e])
|
||||
|
||||
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
|
||||
assert fused_out.size(0) % M == 0, (
|
||||
f"fused_out shape {fused_out.shape} vs M {M}")
|
||||
factor = fused_out.size(0) // M
|
||||
out_chunk_size = CHUNK_SIZE * factor
|
||||
s = chunk_idx * out_chunk_size
|
||||
e = min(s + out_chunk_size, fused_out.size(0))
|
||||
return fused_out[s:e]
|
||||
|
||||
def slice_expert_tokens_metadata(
|
||||
full_expert_tokens_meta: ExpertTokensMetadata,
|
||||
chunk_topk_ids: torch.Tensor, local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
|
||||
# The existing expert_num_tokens is for the entire a1q
|
||||
# input. Chunking forces recomputation of the number
|
||||
# of tokens assigned to each expert.
|
||||
c_expert_num_tokens = count_expert_num_tokens(
|
||||
chunk_topk_ids, local_num_experts, expert_map)
|
||||
|
||||
c_expert_num_tokens_cpu = None
|
||||
need_expert_num_tokens_cpu = (
|
||||
full_expert_tokens_meta.expert_num_tokens_cpu is not None)
|
||||
if need_expert_num_tokens_cpu:
|
||||
# This is blocking as some implementations need the count
|
||||
# on the CPU to determine appropriate input/out fused-moe
|
||||
# buffers
|
||||
c_expert_num_tokens_cpu = c_expert_num_tokens.to(
|
||||
"cpu", non_blocking=False)
|
||||
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=c_expert_num_tokens,
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
|
||||
|
||||
m = None
|
||||
if extra_expert_args is not None and 'm' in extra_expert_args:
|
||||
m = extra_expert_args.get('m')
|
||||
|
||||
if extra_expert_args is not None:
|
||||
chunked_extra_expert_args = extra_expert_args
|
||||
else:
|
||||
chunked_extra_expert_args = {}
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
|
||||
slice_input_tensors(chunk_idx))
|
||||
|
||||
c_expert_tokens_meta = None
|
||||
if expert_tokens_meta is not None:
|
||||
c_expert_tokens_meta = slice_expert_tokens_metadata(
|
||||
expert_tokens_meta, c_topk_ids, local_num_experts,
|
||||
expert_map)
|
||||
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M)
|
||||
|
||||
if m is not None:
|
||||
chunked_extra_expert_args['m'] = e - s
|
||||
self._do_fused_experts(
|
||||
fused_out=slice_output_tensor(chunk_idx),
|
||||
a1=a1,
|
||||
a1q=c_a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=c_topk_weights,
|
||||
topk_ids=c_topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=c_a1q_scale,
|
||||
a2_scale=c_a2_scale,
|
||||
expert_tokens_meta=c_expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=chunked_extra_expert_args)
|
||||
|
||||
return fused_out
|
||||
|
||||
@@ -361,6 +684,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
extra_expert_args: Optional[dict] = None,
|
||||
extra_prepare_args: Optional[dict] = None,
|
||||
extra_finalize_args: Optional[dict] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||
@@ -393,6 +719,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||
applied directly on the inputs. This is only applicable when topk is
|
||||
1.
|
||||
- extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
|
||||
fused_experts.apply.
|
||||
- extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
|
||||
to prepare.
|
||||
- extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
|
||||
to finalize.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
@@ -401,19 +733,31 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1 = hidden_states
|
||||
output = a1 if inplace else torch.zeros_like(a1)
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = w1.size(0)
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1, a1_scale, a2_scale, topk_weights, topk_ids,
|
||||
global_num_experts, expert_map, apply_router_weight_on_input)
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
extra_prepare_args,
|
||||
)
|
||||
|
||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||
topk_weights = (topk_weights if _expert_topk_weights is None else
|
||||
_expert_topk_weights)
|
||||
|
||||
fused_out = None
|
||||
|
||||
if a1q.numel() == 0:
|
||||
# This happens when none of the tokens from the all2all reach this
|
||||
# EP rank. Also, note that this is only relevant for CUDAGraph
|
||||
@@ -423,24 +767,31 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
# and can never run into the tensor.numel() == 0 case.
|
||||
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
|
||||
else:
|
||||
fused_out = self._do_fused_experts(
|
||||
fused_out = self._maybe_chunk_fused_experts(
|
||||
a1=a1,
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale)
|
||||
a2_scale=a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=extra_expert_args)
|
||||
|
||||
self.prepare_finalize.finalize(output, fused_out, topk_weights,
|
||||
topk_ids, apply_router_weight_on_input)
|
||||
self.prepare_finalize.finalize(
|
||||
output, fused_out, topk_weights, topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
extra_finalize_args)
|
||||
|
||||
return output
|
||||
|
||||
146
vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
Normal file
146
vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
|
||||
|
||||
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
Useful in the case when some FusedMoEPermuteExpertsUnpermute
|
||||
implementation does not perform weight application and reduction
|
||||
but cannot address the needs of all the compatible PrepareAndFinalize
|
||||
implementations.
|
||||
For example, BatchedTritonExperts is compatible with both
|
||||
PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize
|
||||
does the weight-application + reduction as part of the pplx combine kernel.
|
||||
But the BatchedPrepareAndFinalize needs an implementation. To facilitate
|
||||
this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
|
||||
so the PrepareAndFinalize implementations could choose how to
|
||||
weight + reduce.
|
||||
"""
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TopKWeightAndReduceDelegate)
|
||||
|
||||
def apply(self, output: Optional[torch.Tensor],
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
raise RuntimeError("The caller is expected to choose an appropriate "
|
||||
"TopKWeightAndReduce implementation.")
|
||||
|
||||
|
||||
class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
The fused_experts outputs have already been weight applied and reduced.
|
||||
This implementation is a no-op.
|
||||
"""
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TopKWeightAndReduceNoOP)
|
||||
|
||||
def apply(self, output: Optional[torch.Tensor],
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
# Weight application and reduction operations are already done.
|
||||
if output is None:
|
||||
return fused_expert_output
|
||||
|
||||
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
|
||||
# tensor.
|
||||
assert output.size() == fused_expert_output.size(), (
|
||||
"output shape is expected to match the fused_expert_output shape. "
|
||||
f"But got output={output.size()}, "
|
||||
f"used_expert_output={fused_expert_output.size()}")
|
||||
output.copy_(fused_expert_output, non_blocking=True)
|
||||
return output
|
||||
|
||||
|
||||
class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
TopKWeightAndReduce implementation for a fused_experts output
|
||||
of shape (m, topk, K)
|
||||
"""
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TopKWeightAndReduceContiguous)
|
||||
|
||||
def apply(self, output: Optional[torch.Tensor],
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
|
||||
m, num_topk = topk_ids.size()
|
||||
k = fused_expert_output.size(-1)
|
||||
if fused_expert_output.ndim == 2:
|
||||
fused_expert_output = fused_expert_output.view(m, num_topk, k)
|
||||
|
||||
assert fused_expert_output.size() == (m, num_topk, k), (
|
||||
f"Expected fused_expert_output size {(m, num_topk, k)}. But got "
|
||||
f"{fused_expert_output.size()}")
|
||||
|
||||
if not apply_router_weight_on_input:
|
||||
fused_expert_output.mul_(topk_weights.view(m, -1, 1))
|
||||
|
||||
if output is None:
|
||||
output = torch.empty((m, k),
|
||||
device=fused_expert_output.device,
|
||||
dtype=fused_expert_output.dtype)
|
||||
assert output.size() == (m, k), (
|
||||
f"Expected output size {(m, k)}. But got {output.size()}")
|
||||
|
||||
ops.moe_sum(fused_expert_output, output)
|
||||
return output
|
||||
|
||||
|
||||
class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
TopKWeightAndReduce implementation for a fused_experts output
|
||||
of shape (num_experts, batch_size, K)
|
||||
"""
|
||||
|
||||
def __init__(self, rank: int):
|
||||
self.rank = rank
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, TopKWeightAndReduceNaiveBatched)
|
||||
and (other.rank == self.rank))
|
||||
|
||||
def apply(self, output: Optional[torch.Tensor],
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
assert fused_expert_output.ndim == 3
|
||||
num_tokens = topk_ids.size(0)
|
||||
num_local_experts = fused_expert_output.size(0)
|
||||
K = fused_expert_output.size(-1)
|
||||
|
||||
if output is None:
|
||||
output = torch.zeros((num_tokens, K),
|
||||
device=fused_expert_output.device,
|
||||
dtype=fused_expert_output.dtype)
|
||||
else:
|
||||
output.fill_(0)
|
||||
|
||||
assert output.size() == (num_tokens, K), (
|
||||
f"Expected output size {(num_tokens, K)}, but got {output.size()}")
|
||||
|
||||
first_expert = num_local_experts * self.rank
|
||||
last_expert = first_expert + num_local_experts
|
||||
|
||||
for expert_id in range(first_expert, last_expert):
|
||||
matching_tokens = topk_ids == expert_id
|
||||
topks = torch.any(matching_tokens, dim=1).flatten()
|
||||
rows = torch.count_nonzero(topks)
|
||||
rhs = fused_expert_output[expert_id - first_expert, :rows, :]
|
||||
if not apply_router_weight_on_input:
|
||||
rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
|
||||
output[topks] = output[topks] + rhs
|
||||
|
||||
return output
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from math import prod
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,7 +10,83 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8, per_token_quant_int8)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
quant_dequant_mxfp4)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv
|
||||
# from vllm.utils.flashinfer import fp4_quantize
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts,
|
||||
topk_numel, expert_map,
|
||||
HAS_EXPERT_MAP: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
|
||||
curr_expert = tl.program_id(0)
|
||||
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
topk_ids_ptrs = topk_ids_ptr + offsets
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)
|
||||
for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)):
|
||||
mask = offsets < (topk_numel - x * BLOCK_SIZE)
|
||||
expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1)
|
||||
if HAS_EXPERT_MAP:
|
||||
expert_map_ptrs = expert_map + expert_ids
|
||||
expert_map_mask = expert_ids >= 0
|
||||
expert_ids = tl.load(expert_map_ptrs,
|
||||
mask=expert_map_mask,
|
||||
other=-1)
|
||||
|
||||
has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0)
|
||||
acc = acc + has_curr_expert
|
||||
topk_ids_ptrs += BLOCK_SIZE
|
||||
|
||||
if curr_expert < num_experts:
|
||||
tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))
|
||||
|
||||
|
||||
def count_expert_num_tokens(
|
||||
topk_ids: torch.Tensor, num_local_experts: int,
|
||||
expert_map: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Count the number to tokens assigned to each expert.
|
||||
|
||||
Parameters:
|
||||
- topk_ids (torch.Tensor): Tensor mapping each token to its
|
||||
list of experts.
|
||||
- num_local_experts (int): Number of experts in this rank.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
|
||||
Returns:
|
||||
A tensor of size num_local_experts, where tensor[i] holds the number
|
||||
of tokens assigned to the ith expert.
|
||||
"""
|
||||
assert topk_ids.dtype.is_signed, (
|
||||
"The kernel uses -1 to represent invalid topk_ids")
|
||||
expert_num_tokens = torch.empty((num_local_experts),
|
||||
device=topk_ids.device,
|
||||
dtype=torch.int32)
|
||||
|
||||
grid = num_local_experts
|
||||
BLOCK_SIZE = min(topk_ids.numel(), 1024)
|
||||
BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE)
|
||||
|
||||
_count_expert_num_tokens[(grid, )](
|
||||
topk_ids,
|
||||
expert_num_tokens,
|
||||
num_local_experts,
|
||||
topk_ids.numel(),
|
||||
expert_map,
|
||||
HAS_EXPERT_MAP=expert_map is not None,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return expert_num_tokens
|
||||
|
||||
|
||||
def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
|
||||
@@ -23,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
|
||||
return x.flatten()[:prod(v)].view(*v)
|
||||
|
||||
|
||||
# def _fp4_quantize(
|
||||
# A: torch.Tensor,
|
||||
# A_scale: Optional[torch.Tensor],
|
||||
# is_sf_swizzled_layout: bool,
|
||||
# ) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# return fp4_quantize(A,
|
||||
# A_scale,
|
||||
# is_sf_swizzled_layout=is_sf_swizzled_layout)
|
||||
|
||||
|
||||
def _fp8_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
@@ -34,9 +120,12 @@ def _fp8_quantize(
|
||||
is provided, the output will be blocked.
|
||||
"""
|
||||
if block_shape is None:
|
||||
# TODO(luka): use QuantFP8 custom op
|
||||
# https://github.com/vllm-project/vllm/issues/20711
|
||||
A, A_scale = ops.scaled_fp8_quant(
|
||||
A, A_scale, use_per_token_if_dynamic=per_act_token)
|
||||
else:
|
||||
assert not per_act_token
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
@@ -62,9 +151,9 @@ def _int8_quantize(
|
||||
if block_shape is None:
|
||||
assert per_act_token, \
|
||||
"int8 quantization only supports block or channel-wise"
|
||||
# A, A_scale = per_token_quant_int8(A)
|
||||
A, A_scale, _ = ops.scaled_int8_quant(A, A_scale)
|
||||
A, A_scale = per_token_quant_int8(A)
|
||||
else:
|
||||
assert not per_act_token
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_int8(A, block_k)
|
||||
@@ -73,19 +162,40 @@ def _int8_quantize(
|
||||
return A, A_scale
|
||||
|
||||
|
||||
def _mxfp4_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
assert block_shape is None
|
||||
if not current_platform.supports_mx():
|
||||
A = quant_dequant_mxfp4(A)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return A, None
|
||||
|
||||
|
||||
def moe_kernel_quantize_input(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
qtype: Optional[torch.dtype],
|
||||
per_channel_quant: bool,
|
||||
quant_dtype: Union[None, torch.dtype, str],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
is_fp4_scale_swizzled: bool = True,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if qtype == torch.float8_e4m3fn:
|
||||
return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
|
||||
elif qtype == torch.int8:
|
||||
return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == torch.int8:
|
||||
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == torch.uint8: # nvfp4
|
||||
return _fp4_quantize(A,
|
||||
A_scale,
|
||||
is_sf_swizzled_layout=is_fp4_scale_swizzled)
|
||||
elif quant_dtype == "mxfp4":
|
||||
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
else:
|
||||
assert A_scale is None
|
||||
return A, A_scale
|
||||
|
||||
|
||||
@@ -97,3 +207,62 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
||||
else:
|
||||
return m[idx, ...]
|
||||
|
||||
|
||||
def normalize_scales_shape(
|
||||
scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
scales = scales.view(1, 1)
|
||||
else:
|
||||
scales = scales.view(-1, scales.size(-1))
|
||||
return scales
|
||||
|
||||
|
||||
def normalize_batched_scales_shape(
|
||||
scales: Optional[torch.Tensor],
|
||||
num_experts: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if scales is not None and scales.ndim < 3:
|
||||
if scales.numel() == 1:
|
||||
scales = scales.view(1)
|
||||
scales = torch.repeat_interleave(scales, num_experts,
|
||||
dim=0).view(num_experts, 1, 1)
|
||||
else:
|
||||
scales = scales.view(num_experts, -1, scales.size(-1))
|
||||
|
||||
return scales
|
||||
|
||||
|
||||
def _validate_scale_shape(
|
||||
a: torch.Tensor,
|
||||
a_scale: Optional[torch.Tensor],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
) -> None:
|
||||
if a_scale is None:
|
||||
return
|
||||
|
||||
if not per_act_token_quant and block_shape is None:
|
||||
assert a_scale.numel() == 1, f"{a_scale.shape}"
|
||||
elif per_act_token_quant:
|
||||
assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, (
|
||||
f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1")
|
||||
else:
|
||||
assert block_shape is not None
|
||||
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
|
||||
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
|
||||
|
||||
|
||||
def extract_required_args(
|
||||
extra_args: Optional[dict[str, Any]],
|
||||
required_keys: list[str],
|
||||
) -> tuple[Any, ...]:
|
||||
if extra_args is None:
|
||||
raise ValueError("`extra_args` must be provided.")
|
||||
|
||||
missing_keys = [k for k in required_keys if k not in extra_args]
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing keys in `extra_args`: {missing_keys}")
|
||||
|
||||
return tuple(extra_args[k] for k in required_keys)
|
||||
|
||||
@@ -36,6 +36,7 @@ QuantizationMethods = Literal[
|
||||
"moe_wna16",
|
||||
"torchao",
|
||||
"auto-round",
|
||||
"mxfp4",
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
@@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .marlin import MarlinConfig
|
||||
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .mxfp4 import Mxfp4Config
|
||||
from .neuron_quant import NeuronQuantConfig
|
||||
from .ptpc_fp8 import PTPCFp8Config
|
||||
from .qqq import QQQConfig
|
||||
@@ -143,6 +145,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"moe_wna16": MoeWNA16Config,
|
||||
"torchao": TorchAOConfig,
|
||||
"auto-round": AutoRoundConfig,
|
||||
"mxfp4": Mxfp4Config,
|
||||
}
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
||||
581
vllm/model_executor/layers/quantization/mxfp4.py
Normal file
581
vllm/model_executor/layers/quantization/mxfp4.py
Normal file
@@ -0,0 +1,581 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import envs
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase, fused_experts)
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
triton_kernel_moe_forward)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_moe_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
_can_support_mxfp4, _swizzle_mxfp4)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
|
||||
next_power_of_2, round_up)
|
||||
|
||||
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
# from flashinfer.fused_moe import cutlass_fused_moe
|
||||
from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
|
||||
|
||||
|
||||
class Mxfp4Config(QuantizationConfig):
|
||||
|
||||
def __init__(self, ignored_layers: Optional[list[str]] = None):
|
||||
super().__init__()
|
||||
self.ignored_layers = ignored_layers
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "mxfp4"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.ignored_layers and is_layer_skipped(
|
||||
prefix=prefix,
|
||||
ignored_layers=self.ignored_layers,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
raise NotImplementedError("Mxfp4 linear layer is not implemented")
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Mxfp4MoEMethod(layer.moe_config)
|
||||
elif isinstance(layer, Attention):
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 attention layer is not implemented")
|
||||
return None
|
||||
|
||||
|
||||
class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__()
|
||||
self.topk_indices_dtype = None
|
||||
self.moe = moe
|
||||
self.use_marlin = self._should_use_marlin()
|
||||
|
||||
def _should_use_marlin(self):
|
||||
if envs.VLLM_MXFP4_USE_MARLIN is not None:
|
||||
return envs.VLLM_MXFP4_USE_MARLIN
|
||||
# if current_platform.is_cuda() and \
|
||||
# not current_platform.has_device_capability(100):
|
||||
# if not current_platform.is_device_capability(90):
|
||||
# # marlin kernel has better performance on ampere
|
||||
# return True
|
||||
# if not has_triton_kernels():
|
||||
# return True
|
||||
# if not is_torch_equal_or_newer("2.8.0"):
|
||||
# return True
|
||||
return False
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
self.num_experts = num_experts
|
||||
weight_dtype = torch.uint8
|
||||
scale_dtype = torch.uint8
|
||||
|
||||
# FIXME (zyongye): ship after torch and safetensors support mxfp4
|
||||
# is_torch_mxfp4_available = (
|
||||
# hasattr(torch, "float4_e2m1fn_x2") and
|
||||
# hasattr(torch, "float8_e8m0fnu"))
|
||||
# if is_torch_mxfp4_available:
|
||||
# weight_dtype = torch.float4_e2m1fn_x2
|
||||
# scale_dtype = torch.float8_e8m0fnu
|
||||
|
||||
mxfp4_block = 32
|
||||
|
||||
intermediate_size_per_partition_after_pad = \
|
||||
intermediate_size_per_partition
|
||||
if self.use_marlin:
|
||||
# The moe marlin kernel requires that for each linear
|
||||
# n % 256 == 0 and k % 128 == 0.
|
||||
# In gate_up_proj:
|
||||
# n = 2 * intermediate_size_per_partition_after_pad
|
||||
# k = hidden_size
|
||||
# In down_proj
|
||||
# n = hidden_size
|
||||
# k = intermediate_size_per_partition_after_pad
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 128)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
|
||||
layer.params_dtype = params_dtype
|
||||
layer.num_experts = num_experts
|
||||
layer.hidden_size = hidden_size
|
||||
layer.intermediate_size_per_partition = \
|
||||
intermediate_size_per_partition_after_pad
|
||||
elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||
# for to hold non-uniform sharded tensor as well as swizzling
|
||||
# other padding to increase performance
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 256)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
elif current_platform.is_rocm():
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 128)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 64)
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
self.hidden_size = hidden_size
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
|
||||
w13_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_bias", w13_bias)
|
||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // 2,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
w2_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_bias", w2_bias)
|
||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
layer.gemm1_alpha = Parameter(torch.tensor(
|
||||
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
layer.gemm1_beta = Parameter(torch.tensor(
|
||||
[1.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
layer.gemm1_clamp_limit = Parameter(torch.tensor(
|
||||
[7.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
sf_block_size = 32 # mxfp4 block size
|
||||
|
||||
assert (layer.w13_weight.dim() == 3
|
||||
and layer.w13_weight.shape[0] == self.num_experts
|
||||
and layer.w13_weight.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight.shape[2] == self.hidden_size // 2)
|
||||
assert (layer.w13_weight_scale.dim() == 3
|
||||
and layer.w13_weight_scale.shape[0] == self.num_experts
|
||||
and layer.w13_weight_scale.shape[1]
|
||||
== self.intermediate_size * 2
|
||||
and layer.w13_weight_scale.shape[2]
|
||||
== self.hidden_size // sf_block_size)
|
||||
assert (layer.w2_weight.dim() == 3
|
||||
and layer.w2_weight.shape[0] == self.num_experts
|
||||
and layer.w2_weight.shape[1] == self.hidden_size and
|
||||
layer.w2_weight.shape[2] == self.intermediate_size // 2)
|
||||
assert (layer.w2_weight_scale.dim() == 3
|
||||
and layer.w2_weight_scale.shape[1] == self.hidden_size
|
||||
and layer.w2_weight_scale.shape[2]
|
||||
== self.intermediate_size // sf_block_size)
|
||||
assert (layer.w13_bias.dim() == 2
|
||||
and layer.w13_bias.shape[0] == self.num_experts
|
||||
and layer.w13_bias.shape[1] == self.intermediate_size * 2)
|
||||
assert (layer.w2_bias.dim() == 2
|
||||
and layer.w2_bias.shape[0] == self.num_experts
|
||||
and layer.w2_bias.shape[1] == self.hidden_size)
|
||||
|
||||
w13_weight_scale = layer.w13_weight_scale.data
|
||||
w2_weight_scale = layer.w2_weight_scale.data
|
||||
w13_weight = layer.w13_weight.data
|
||||
w2_weight = layer.w2_weight.data
|
||||
w13_bias = layer.w13_bias.data.to(torch.float32)
|
||||
w2_bias = layer.w2_bias.data.to(torch.float32)
|
||||
|
||||
# Swap w1 and w3 as the defenition of
|
||||
# swiglu is different in the trtllm-gen
|
||||
def swap_every_two_rows(x, axis=-1):
|
||||
shape = x.shape
|
||||
if axis < 0:
|
||||
axis = len(shape) + axis
|
||||
|
||||
# Create a new shape with pairs swapped along specified axis
|
||||
new_shape = list(shape)
|
||||
new_shape[axis] = shape[axis] // 2
|
||||
new_shape.insert(axis + 1, 2)
|
||||
|
||||
# Reshape to expose pairs, swap them, and reshape back
|
||||
x = x.reshape(*new_shape)
|
||||
x = x.flip(axis + 1)
|
||||
new_shape = list(shape)
|
||||
return x.reshape(*new_shape)
|
||||
|
||||
w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
|
||||
w13_weight = swap_every_two_rows(w13_weight, -2)
|
||||
w13_bias = swap_every_two_rows(w13_bias, -1)
|
||||
|
||||
# Do not interleave as the checkpoint is already interleaved
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_mxfp4_shuffled = []
|
||||
gemm1_scales_mxfp4_shuffled = []
|
||||
gemm2_weights_mxfp4_shuffled = []
|
||||
gemm2_scales_mxfp4_shuffled = []
|
||||
gemm1_bias_shuffled = []
|
||||
gemm2_bias_shuffled = []
|
||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||
for i in range(self.num_experts):
|
||||
gemm1_weights_mxfp4_shuffled.append(
|
||||
shuffle_matrix_a(w13_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm1_scales_mxfp4_shuffled.append(
|
||||
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm1_bias_shuffled.append(
|
||||
shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m))
|
||||
|
||||
gemm2_weights_mxfp4_shuffled.append(
|
||||
shuffle_matrix_a(w2_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm2_scales_mxfp4_shuffled.append(
|
||||
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm2_bias_shuffled.append(
|
||||
shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m))
|
||||
|
||||
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
|
||||
w13_weight_scale = torch.stack(
|
||||
gemm1_scales_mxfp4_shuffled).reshape(
|
||||
self.num_experts, 2 * self.intermediate_size,
|
||||
self.hidden_size // sf_block_size).view(
|
||||
torch.float8_e4m3fn)
|
||||
|
||||
w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
|
||||
w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape(
|
||||
self.num_experts, self.hidden_size, self.intermediate_size //
|
||||
sf_block_size).view(torch.float8_e4m3fn)
|
||||
|
||||
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(w13_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale = Parameter(w2_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.w13_bias = Parameter(
|
||||
torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
|
||||
requires_grad=False)
|
||||
layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
|
||||
self.num_experts, -1),
|
||||
requires_grad=False)
|
||||
elif has_triton_kernels():
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
w13_bias = layer.w13_bias.to(torch.float32)
|
||||
w2_bias = layer.w2_bias.to(torch.float32)
|
||||
|
||||
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
|
||||
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
|
||||
|
||||
# FIXME warp need to be adjusted based on batch size
|
||||
# only apply to batched mode
|
||||
if self.moe.use_ep:
|
||||
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
|
||||
else:
|
||||
num_warps = 8
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps)
|
||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||
layer.w2_weight, layer.w2_weight_scale, num_warps)
|
||||
|
||||
self.w13_precision_config = PrecisionConfig(
|
||||
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))
|
||||
|
||||
self.w13_weight_triton_tensor = w13_weight
|
||||
self.w2_weight_triton_tensor = w2_weight
|
||||
|
||||
# need to delete the original weights to save memory on single GPU
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
layer.w13_weight = None
|
||||
layer.w2_weight = None
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# normal triton
|
||||
from .triton_kernels_numerics_details.mxfp import upcast_from_mxfp
|
||||
w13_weight = upcast_from_mxfp(
|
||||
layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
|
||||
)
|
||||
w2_weight = upcast_from_mxfp(
|
||||
layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
|
||||
)
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
del layer.w13_weight_scale
|
||||
del layer.w2_weight_scale
|
||||
layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
|
||||
layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
|
||||
# Number of tokens in the input tensor.
|
||||
num_tokens = x.shape[0]
|
||||
# Factor to account for the imbalance of the experts.
|
||||
# factor equals to the
|
||||
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
|
||||
# - 1.0 means perfect expert distribution.
|
||||
# - > 1.0 means some experts have more
|
||||
# tokens than the perfect distribution.
|
||||
# - < 1.0 does not make sense.
|
||||
imbalance_factor = 1.3
|
||||
# Calculate the number of tokens per expert
|
||||
# assuming perfect distribution.
|
||||
num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
|
||||
# Apply the imbalance factor.
|
||||
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
|
||||
# And pad the number to the next power of 2.
|
||||
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
|
||||
# Cap to 8-64 tokens per CTA tile
|
||||
# as it's the range supported by the kernel.
|
||||
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
if self.use_marlin:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layer.w13_bias,
|
||||
layer.w2_bias,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_scale1=None,
|
||||
global_scale2=None,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
activation=activation,
|
||||
expert_map=expert_map)
|
||||
|
||||
assert _can_support_mxfp4(
|
||||
use_grouped_topk, topk_group, num_expert_group, expert_map,
|
||||
custom_routing_function, e_score_correction_bias,
|
||||
apply_router_weight_on_input, scoring_func, activation,
|
||||
expert_load_view, logical_to_physical_map,
|
||||
logical_replica_count), (
|
||||
"MXFP4 are not supported with this configuration.")
|
||||
|
||||
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
assert not self.moe.use_ep, (
|
||||
"EP is not supported for flashinfer mxfp4 moe backend yet.")
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
|
||||
assert x.dtype == torch.bfloat16
|
||||
x_quant = x
|
||||
x_scale = None
|
||||
else:
|
||||
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||
router_logits.to(torch.bfloat16),
|
||||
None, # routing_bias
|
||||
x_quant,
|
||||
x_scale,
|
||||
layer.w13_weight, # uint8 (e2m1 x 2)
|
||||
layer.w13_weight_scale, # uint8 (e4m3 x 2)
|
||||
layer.w13_bias, # fp32 per expert per channel
|
||||
layer.gemm1_alpha, # fp32 per expert
|
||||
layer.gemm1_beta, # fp32 per expert
|
||||
layer.gemm1_clamp_limit, # fp32 per expert
|
||||
layer.w2_weight, # uint8 (e2m1 x 2)
|
||||
layer.w2_weight_scale, # ue8m0
|
||||
layer.w2_bias, # fp32 per expert per channel
|
||||
None, # output1_scale_scalar
|
||||
None, # output1_scale_gate_scalar
|
||||
None, # output2_scale_scalar
|
||||
self.num_experts,
|
||||
top_k,
|
||||
None, # n_group
|
||||
None, # topk_group
|
||||
self.intermediate_size, # padded to multiple of 256
|
||||
0, # local_expert_offset
|
||||
self.num_experts, # local num experts
|
||||
None,
|
||||
self._get_tile_tokens_dim(x, top_k),
|
||||
1 if renormalize else 0, # routing_method_type, renormalize
|
||||
True, # do finalize
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
elif has_triton_kernels():
|
||||
return triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=self.w13_weight_triton_tensor,
|
||||
w2=self.w2_weight_triton_tensor,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_precision=self.w13_precision_config,
|
||||
w2_precision=self.w2_precision_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
)
|
||||
@@ -6,14 +6,16 @@ from typing import Any, Callable, Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
|
||||
OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["QuarkW4A4MXFP4"]
|
||||
|
||||
|
||||
@@ -25,7 +27,29 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
||||
self.qscheme = "per_group"
|
||||
self.weight_quant_spec = weight_quant_spec
|
||||
self.input_quant_spec = input_quant_spec
|
||||
self.emulate = not current_platform.supports_mx()
|
||||
|
||||
self.static_input_scales = not input_quant_spec.get("is_dynamic")
|
||||
|
||||
if self.static_input_scales:
|
||||
raise NotImplementedError(
|
||||
"QuarkW4A4MXFP4 with static input scales is currently not "
|
||||
"implemented. Please open an issue.")
|
||||
|
||||
if not current_platform.supports_mx():
|
||||
self.emulate = True
|
||||
logger.warning_once(
|
||||
"The current platform does not support native MXFP4 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision.")
|
||||
else:
|
||||
self.emulate = True
|
||||
logger.warning_once(
|
||||
"The current platform supports native MXFP4 "
|
||||
"computation, but kernels are not yet integrated in vLLM. "
|
||||
"Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision.")
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@@ -37,43 +61,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
||||
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
|
||||
requires_grad=False)
|
||||
|
||||
if self.emulate:
|
||||
try:
|
||||
from quark.torch.export.nn.modules import realquantizer
|
||||
from quark.torch.quantization.config.config import (
|
||||
QuantizationSpec)
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"The package `amd-quark` is required to use AMD Quark "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
weight_quant_spec = QuantizationSpec.from_dict(
|
||||
self.weight_quant_spec)
|
||||
|
||||
weight_quantizer = realquantizer.get_real_quantizer(
|
||||
qspec=weight_quant_spec,
|
||||
quantizer=None,
|
||||
real_quantized=True,
|
||||
reorder=False,
|
||||
float_dtype=self.out_dtype,
|
||||
scale_shape=layer.weight_scale.shape,
|
||||
zero_point_shape=None,
|
||||
)
|
||||
weight_quantizer.scale.data = layer.weight_scale.data
|
||||
|
||||
if not envs.VLLM_QUARK_EMU_MEM_OPT:
|
||||
layer.weight = torch.nn.Parameter(
|
||||
weight_quantizer(layer.weight.data).to(self.out_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
self.weight_quantizer = weight_quantizer
|
||||
layer.weight_scale = None
|
||||
|
||||
# This call is necessary to release the scales memory.
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
@@ -116,11 +103,10 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.emulate:
|
||||
if envs.VLLM_QUARK_EMU_MEM_OPT:
|
||||
dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
|
||||
else:
|
||||
dq_w = layer.weight
|
||||
qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
|
||||
return F.linear(qdq_x, dq_w, bias)
|
||||
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
|
||||
|
||||
x = quant_dequant_mxfp4(x)
|
||||
|
||||
return F.linear(x, dq_w, bias)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# fmt: off
|
||||
|
||||
|
||||
MXFP_BLOCK_SIZE = tl.constexpr(32)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _get_max_quant_val(dtype: tl.constexpr):
|
||||
if dtype == tl.uint8:
|
||||
return 6.0
|
||||
elif dtype == tl.float8e5:
|
||||
return 57344.0
|
||||
elif dtype == tl.float8e4nv:
|
||||
return 448.0
|
||||
else:
|
||||
tl.static_assert(False, f"Invalid {dtype=}")
|
||||
|
||||
@triton.jit
|
||||
def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
|
||||
DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
|
||||
is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
|
||||
BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
|
||||
BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
|
||||
BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
|
||||
|
||||
# Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
|
||||
f32_tensor = src_tensor.to(tl.float32)
|
||||
abs_tensor = tl.abs(f32_tensor)
|
||||
abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation
|
||||
abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
|
||||
max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
|
||||
dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
|
||||
if DEQUANT_SCALE_ROUNDING_MODE == 0:
|
||||
# DequantScaleRoundingMode.ROUND_UP
|
||||
# compute 2 ** ceil(log2(dequant_scale))
|
||||
# Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
|
||||
# A corner case: exponent is 0xFF that will overflow but that's already
|
||||
# NaN so assume we don't care.
|
||||
dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
|
||||
else:
|
||||
# DequantScaleRoundingMode.ROUND_DOWN
|
||||
# compute 2 ** floor(log2(dequant_scale))
|
||||
assert DEQUANT_SCALE_ROUNDING_MODE == 1
|
||||
dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
|
||||
dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
|
||||
quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)
|
||||
|
||||
f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
|
||||
quant_tensor = f32_tensor * quant_scale
|
||||
|
||||
# Reshape the tensors after scaling
|
||||
quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
|
||||
# Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
|
||||
quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
|
||||
dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])
|
||||
|
||||
# First, we simply extract the exponent part of the scales and store the result
|
||||
dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
|
||||
# Now we must convert the tensors to the mx format.
|
||||
if is_fp8:
|
||||
out_tensor = quant_tensor.to(mx_tensor_dtype)
|
||||
else:
|
||||
quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
|
||||
signs = quant_tensor & 0x80000000
|
||||
exponents = (quant_tensor >> 23) & 0xFF
|
||||
mantissas = (quant_tensor & 0x7FFFFF)
|
||||
|
||||
# 0.25 <= x < 0.75 maps to 0.5, a denormal number
|
||||
E8_BIAS = 127
|
||||
E2_BIAS = 1
|
||||
# Move implicit bit 1 at the beginning to mantissa for denormals
|
||||
adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
|
||||
mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)
|
||||
|
||||
# For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
|
||||
exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
|
||||
|
||||
# Combine sign, exponent, and mantissa, while saturating
|
||||
# rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
|
||||
e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
|
||||
e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)
|
||||
|
||||
e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
|
||||
evens, odds = tl.split(e2m1_value)
|
||||
out_tensor = evens | (odds << 4)
|
||||
|
||||
return out_tensor, dequant_scale_exponent
|
||||
|
||||
@triton.jit
def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
                      mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
                      src_ptr, stride_src_outer, stride_src_quant,
                      outer_dim, quant_dim,
                      BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
                      DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
    """Quantize a 2D (outer_dim x quant_dim) source tensor into mx format.

    Each program instance processes one BLOCK_SIZE_OUT_DIM x BLOCK_SIZE_QUANT_DIM
    tile: it loads the source values, delegates the actual quantization and the
    per-32-element shared-exponent computation to `_compute_quant_and_scale`,
    then stores the quantized values to `mx_tensor_ptr` and the uint8 scales to
    `mx_scale_ptr`.  For fp4 output (uint8 element type) two e2m1 values are
    packed per byte, so the quantized axis of the output is half of `quant_dim`.
    """
    # The packed output must be contiguous along the quantized axis.
    tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")

    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
    tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
                     f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")

    src_dtype: tl.constexpr = src_ptr.dtype.element_ty
    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
    tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16) or (src_dtype == tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32")
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8

    # int64 indices guard against overflow on very large tensors.
    outer_block = tl.program_id(0).to(tl.int64)
    quant_block = tl.program_id(1).to(tl.int64)

    # fp4 packs two values per output byte, fp8 stores one value per byte.
    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR

    # Starting element indices of this tile in each of the three buffers.
    start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
    start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
    start_out = outer_block * BLOCK_SIZE_OUT_DIM

    src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
    mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
    mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer

    # Row/column offset grids; quant offsets are row vectors, outer a column.
    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
    offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
    offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)

    # Bounds masks for the (possibly ragged) tile edges.
    mask_src_quant = start_src_quant + offs_src_quant < quant_dim
    mask_n = start_out + offs_outer < outer_dim
    full_mask_src = mask_src_quant & mask_n

    mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
    full_mask_mxt = mask_mxt_quant & mask_n

    scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
    full_scale_mask = scale_mask_k & mask_n

    src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
    mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
    mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
    src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)

    # The heavy lifting (scale selection, rounding, fp4 packing) happens here.
    out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
                                                        DEQUANT_SCALE_ROUNDING_MODE)

    tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
    tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
|
||||
|
||||
|
||||
@triton.jit(repr=lambda _: "_dequantize_mxfp8")
def _dequantize_mxfp8_fn(input, mask, pid=None):
    # NOTE(review): despite the "dequantize" name this delegates to
    # `_compute_quant_and_scale` with an fp8 target dtype — confirm the helper's
    # fp8 path implements the behavior this callback name promises.
    # `pid` is accepted to satisfy the callback signature but is unused here.
    return _compute_quant_and_scale(input, mask, tl.float8e4nv)
|
||||
@@ -0,0 +1,136 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
|
||||
|
||||
|
||||
# fmt: off
|
||||
@triton.jit
def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
                      stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
                      outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
    """Dequantize an mx-format tensor into float16/bfloat16/float32.

    Each program instance handles one BLOCK_SIZE_OUT_DIM x BLOCK_SIZE_QUANT_DIM
    output tile.  The packed values (fp4 pairs in uint8, or fp8) are upcast by
    direct bit manipulation, multiplied by their per-32-element power-of-two
    scale (uint8 biased exponent), and stored to `out_ptr`.
    """
    tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
    dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
    tl.static_assert(dst_dtype == tl.float16 or (dst_dtype == tl.bfloat16 or dst_dtype == tl.float32))
    tl.static_assert(
        mx_tensor_dtype == tl.uint8
        or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
        "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")

    # Determine if we are dealing with fp8 types.
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
    # fp4 input holds two logical values per stored byte.
    K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR

    # Compute starting indices for the quantized (packed) dimension and the outer dimension.
    outer_block = tl.program_id(0).to(tl.int64)
    quant_block = tl.program_id(1).to(tl.int64)

    start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
    start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
    start_out = outer_block * BLOCK_SIZE_OUT_DIM

    mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
    mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
    out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant

    # Compute offsets and masks.
    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
    offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
    offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)

    mask_outer = start_out + offs_outer < outer_dim
    mask_out_quant = start_out_quant + offs_out_quant < quant_dim
    full_mask_out = mask_out_quant & mask_outer

    mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
    full_mask_src = mask_src_quant & mask_outer

    mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
    full_scale_mask = mask_scale & mask_outer

    tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
    scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
    out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer

    # Load the packed tensor and scale.
    tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
    scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)

    # Upcast the scale to the destination type.
    # The uint8 scale is a biased exponent (e8m0); shifting it into the
    # exponent field of the destination float reconstructs 2**(scale - 127).
    if dst_dtype == tl.bfloat16:
        # dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
        dst_scale = (scale.to(tl.uint16) << 7).to(tl.uint16).to(tl.bfloat16, bitcast=True)
    else:
        dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
        if dst_dtype == tl.float16:
            dst_scale = dst_scale.to(tl.float16)

    # Now upcast the tensor.
    # For a float32 destination we first widen through bfloat16 (same exponent
    # layout) and convert to float32 at the end.
    intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
    if is_fp8:
        dst_tensor = tensor.to(intermediate_dtype)
        if tensor.dtype == tl.float8e5:
            from_e_bits: tl.constexpr = 5
            from_m_bits: tl.constexpr = 2
            to_e_bits: tl.constexpr = 8 if intermediate_dtype == tl.bfloat16 else 5
            to_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10

            # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
            non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
            non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
            dst_tensor = tl.where(
                (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
                (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(intermediate_dtype, bitcast=True),
                dst_tensor,
            )
    else:
        assert is_fp4
        dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15
        dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
        dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
        # e2m1: low nibble is element 0, high nibble is element 1 of each pair.
        em0 = tensor & 0x07
        em1 = tensor & 0x70
        # Place exponent+mantissa and sign bits into the destination layout.
        x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
        x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
        # Three cases:
        # 1) x is normal and non-zero: Correct bias
        x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
        x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
        # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
        x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
        x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
        # 3) x is zero, do nothing
        # Interleave restores the original even/odd element order after unpacking.
        if intermediate_dtype == tl.bfloat16:
            dst_tensor = tl.interleave(x0, x1).to(tl.uint16).to(tl.bfloat16, bitcast=True)
        else:
            dst_tensor = tl.interleave(x0, x1).to(tl.float16, bitcast=True)
    # dst_tensor = dst_tensor.to(dst_dtype)
    if dst_dtype == tl.bfloat16:
        dst_tensor = dst_tensor.to(tl.bfloat16)
    elif dst_dtype == tl.float16:
        dst_tensor = dst_tensor.to(tl.float16)
    else:
        dst_tensor = dst_tensor.to(tl.float32)

    # Reshape for proper broadcasting: the scale was stored with a 32-sized "inner" grouping.
    dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
    scale = scale.reshape(dst_scale.shape)

    out_tensor = dst_tensor * dst_scale
    # Correct any NaNs encoded via the scale (0xFF is the e8m0 NaN encoding).
    out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
    out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)
|
||||
@@ -0,0 +1,303 @@
|
||||
# isort: off
|
||||
# fmt: off
|
||||
from enum import Enum
|
||||
import triton
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from ._upcast_from_mxfp import _upcast_from_mxfp
|
||||
from ._downcast_to_mxfp import _downcast_to_mxfp, _dequantize_mxfp8_fn, MXFP_BLOCK_SIZE
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dequantization / Quantization Utilities
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DequantScaleRoundingMode(Enum):
    """How the fp32 dequantization scale is rounded to a power of two."""

    ROUND_UP = 0
    ROUND_DOWN = 1
|
||||
|
||||
|
||||
def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
                     DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
    """
    Quantize ``src_tensor`` to the mx format along ``axis``.

    If ``out_quant_type`` is torch.uint8 the output is mxfp4: two e2m1 values
    are packed into each byte, so the quantized axis is half its logical
    length.  If it is torch.float8_e4m3fn or torch.float8_e5m2 the output is
    mxfp8 stored directly in that float8 format.

    Returns a tuple ``(quantized, scales)`` where ``scales`` holds one uint8
    shared exponent per group of 32 elements along ``axis``.
    """
    rank = src_tensor.ndim
    assert -rank <= axis < rank, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + rank

    # Work with the quantized axis as the innermost dimension.
    src_tensor = src_tensor.transpose(axis, rank - 1)

    is_fp4 = out_quant_type == torch.uint8
    is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
    assert is_fp4 or is_fp8

    # fp4 packs two values per output byte; fp8 is one value per byte.
    pack_factor = 2 if is_fp4 else 1
    axis_len = src_tensor.shape[-1]
    if is_fp4:
        assert axis_len % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {axis_len}"

    quant_shape = src_tensor.shape[:-1] + (axis_len // pack_factor, )
    scale_shape = src_tensor.shape[:-1] + (triton.cdiv(axis_len, MXFP_BLOCK_SIZE), )

    out_quant_tensor = src_tensor.new_empty(quant_shape, dtype=out_quant_type)
    out_scale = src_tensor.new_empty(scale_shape, dtype=torch.uint8)

    if src_tensor.numel() > 0:
        # Collapse the leading dimensions so the kernel sees a 2D problem.
        src_2d = src_tensor.reshape(-1, src_tensor.shape[-1])
        quant_2d = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
        scale_2d = out_scale.view(-1, out_scale.shape[-1])

        BLOCK_OUT_DIM = 128
        BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
        grid = (triton.cdiv(src_2d.shape[0], BLOCK_OUT_DIM),
                triton.cdiv(src_2d.shape[1], BLOCK_QUANT_DIM))

        _downcast_to_mxfp[grid](quant_2d, *quant_2d.stride(), scale_2d,
                                *scale_2d.stride(), src_2d, *src_2d.stride(),
                                *src_2d.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
                                DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)

    # Undo the initial transpose on both outputs.
    out_quant_tensor = out_quant_tensor.transpose(axis, rank - 1)
    out_scale = out_scale.transpose(axis, rank - 1)
    return out_quant_tensor, out_scale
|
||||
|
||||
|
||||
def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int):
    """
    Dequantize an mx-format (packed) tensor back to float16/bfloat16/float32.

    Both inputs are assumed to have been quantized along ``axis``.  The tensors
    are permuted so that axis is last, flattened to 2D for the Triton kernel,
    and the result is permuted back to the original dimension order.
    """
    rank = tensor.ndim
    assert -rank <= axis < rank, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + rank
    assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
                                       f"Got {tensor.ndim=} and {scale.ndim=}")
    # dtype checks
    assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
        f"Invalid tensor dtype {tensor.dtype=}"
    assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
    assert dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {dtype=}"

    # uint8 packs two fp4 values per byte, so the logical length doubles.
    logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)

    # Move the quantized axis last and flatten to 2D for the kernel.
    tensor = tensor.transpose(axis, rank - 1).contiguous()
    scale = scale.transpose(axis, rank - 1).contiguous()
    out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device)

    out_2d = out.view(-1, out.shape[-1])
    tensor_2d = tensor.view(-1, tensor.shape[-1])
    scale_2d = scale.view(-1, scale.shape[-1])

    BLOCK_OUT_DIM = 128
    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
    grid = (triton.cdiv(out_2d.shape[0], BLOCK_OUT_DIM),
            triton.cdiv(out_2d.shape[1], BLOCK_QUANT_DIM))
    _upcast_from_mxfp[grid](out_2d, *out_2d.stride(), scale_2d,
                            *scale_2d.stride(), tensor_2d,
                            *tensor_2d.stride(), *out_2d.shape, BLOCK_OUT_DIM,
                            BLOCK_QUANT_DIM, num_warps=8)
    # Restore the original dimension order.
    out = out.transpose(axis, rank - 1).contiguous()
    return out
|
||||
|
||||
|
||||
# ------------
|
||||
|
||||
|
||||
def right_shift_unsigned(x, shift):
    """Logical (zero-fill) right shift for 32-bit values.

    CUDA torch does not support bit ops on uint32, so the arithmetic shift is
    masked down to the low ``32 - shift`` bits to emulate an unsigned shift.
    """
    low_bits_mask = (1 << (32 - shift)) - 1
    return (x >> shift) & low_bits_mask
|
||||
|
||||
|
||||
def get_max_quant_val(dtype: torch.dtype):
|
||||
d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0}
|
||||
assert dtype in d
|
||||
return d[dtype]
|
||||
|
||||
|
||||
def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
                           DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
    """
    Converts the src tensor to the output format specified by out_quant_type.
    axis: The axis along which the tensors are contiguous and quantization is applied.
    DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN.

    Returns:
        out_quant_tensor: Quantized tensor in mx format.
            For mxfp8, the output has the same shape as src_tensor.
            For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
        scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
            Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
            where L is the original length along that axis.
    """
    # This should probably be packed into its own tiny class
    ndim = src_tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    assert src_tensor.dtype in {torch.float32, torch.bfloat16,
                                torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}"

    axis = axis if axis >= 0 else axis + ndim
    is_fp4 = out_quant_type == torch.uint8
    is_fp8 = "float8" in str(out_quant_type)
    assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"

    device = src_tensor.device

    # For mxfp4 conversion, we assume the contiguous axis length is even.
    if is_fp4:
        axis_shape = src_tensor.size(axis)
        assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even."

    # Permute the tensor so that the contiguous axis becomes the last dimension.
    # All bit manipulation below is done on fp32 representations.
    src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
    axis_shape = src.shape[-1]

    # Pad the axis to be divisible by 32, in case it is not.
    next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
    pad_amount = next_multiple - axis_shape
    padded_src = F.pad(src, (0, pad_amount))
    valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
    padded_axis_shape = padded_src.size(-1)  # now divisible by 32

    # --- Compute per-group maximums for scale ---
    # Set padded entries to -1 so they don't affect the max.
    abs_f = torch.abs(padded_src)
    abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
    # Reshape the last dimension into groups of 32.
    new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
    abs_groups = abs_f.view(*new_shape)
    # Compute maximum along the group dimension (of size 32).
    max_val, _ = abs_groups.max(dim=-1, keepdim=True)

    # Choose a max quantization value depending on type.
    max_quant_val = get_max_quant_val(out_quant_type)
    dequant_scale = max_val / max_quant_val  # shape: (..., padded_axis_shape//32, 1)

    # Convert to int to round the FP32 scale, prior to quantization!
    # The scale must be a power of two (only the exponent bits survive), so
    # round the fp32 bit pattern up or down to the next exponent boundary.
    ds_int = dequant_scale.view(torch.int32)
    if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
        # Adding the full mantissa (0x007FFFFF) before masking rounds any
        # non-zero mantissa up to the next power of two.
        ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000
    else:
        ds_int_rounded = ds_int & 0x7F800000
    # Reinterpret back as float32.
    dequant_scale_rounded = ds_int_rounded.view(torch.float32)

    # Compute the quantization scale; guard against division by a zero scale.
    quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded)

    # Quantize the tensor
    orig_padded_shape = padded_src.shape
    padded_src_groups = padded_src.view(*new_shape)
    quant_tensor = padded_src_groups * quant_scale
    # Reshape back to the original shape and trim padding
    quant_tensor = quant_tensor.view(orig_padded_shape)
    quant_tensor = quant_tensor[..., :axis_shape]

    # Finally, convert the quantized tensor to the target format
    if is_fp8:
        # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
        quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val)
        out_weight = quant_tensor.to(out_quant_type)
    else:
        assert is_fp4, f"Invalid output quantization type {out_quant_type}"
        # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
        # First, reinterpret the quantized tensor bits.
        q_int = quant_tensor.contiguous().view(torch.int32)
        # Extract sign, exponent, and mantissa.
        signs = q_int & 0x80000000
        exponents = right_shift_unsigned(q_int, 23) & 0xFF
        mantissas = q_int & 0x7FFFFF

        # fp32 exponent bias vs. e2m1 exponent bias.
        E8_BIAS = 127
        E2_BIAS = 1
        # Adjust mantissas for subnormals: make the implicit leading 1 explicit
        # and shift the mantissa right by the exponent deficit.
        mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
                                (E8_BIAS - exponents - 1), mantissas)
        # Re-bias normal exponents to e2m1; subnormals keep exponent 0.
        exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
        # Round-to-nearest (ties up): add 1 below the LSB, then shift right.
        e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1)
        # Saturate to the largest e2m1 magnitude code.
        e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))
        e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8)  # shape: (..., even_axis_shape)

        # Pack pairs of 4-bit values along the last dimension.
        e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2)
        evens = e2m1_value[..., 0]
        odds = e2m1_value[..., 1]
        out_weight = evens | (odds << 4)  # shape: (..., axis_shape//2)

    # --- Process and output the scale ---
    # Keep only the biased exponent of the power-of-two scale (e8m0 encoding).
    dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8)  # shape: (..., axis_shape//32, 1)
    dq_scale = dq_scale.squeeze(-1)
    # Undo the initial transpose on both outputs.
    out_weight = out_weight.transpose(axis, src_tensor.ndim - 1)
    dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1)
    return out_weight, dq_scale
|
||||
|
||||
|
||||
def cvt_e2m1_to_fp32(input_tensor):
    """Decode packed e2m1 (fp4) pairs stored in uint8 bytes to float32.

    The low nibble of each byte is the first element of the pair and the high
    nibble the second, so the output's last dimension is twice the input's.
    """
    assert input_tensor.dtype == torch.uint8

    codes = input_tensor.to(torch.int32)
    low_nibbles = codes & 0xF
    high_nibbles = (codes >> 4) & 0xF

    # Lookup table: indices 0-7 are the positive e2m1 magnitudes, 8-15 the
    # negatives (bit 3 of the nibble is the sign).
    magnitudes = torch.tensor([0.0, 0.5, 1, 1.5, 2, 3, 4, 6],
                              dtype=torch.float32, device=input_tensor.device)
    lut = torch.cat([magnitudes, -magnitudes])

    decoded_pairs = torch.stack([lut[low_nibbles], lut[high_nibbles]], dim=-1)
    return decoded_pairs.view(*codes.shape[:-1], -1)
|
||||
|
||||
|
||||
def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
    """
    Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
    axis: The axis along which dequantization is applied.

    Returns:
        out_weight: Tensor in the target format.
    """

    ndim = tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2
    assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}"

    # Permute the tensor and scale so that the quantization axis becomes the last dimension
    axis = axis if axis >= 0 else axis + ndim
    scale = scale.transpose(axis, scale.ndim - 1)
    tensor = tensor.transpose(axis, tensor.ndim - 1)

    # The uint8 scale is an e8m0 biased exponent: shifting it into the fp32
    # exponent field reconstructs the power-of-two dequantization factor.
    dq_scale = (scale.to(torch.int32) << 23).view(torch.float32)  # Shift to the exponent and bitcast to fp32
    if tensor.dtype == torch.uint8:
        # fp4: unpack two e2m1 values per byte into float32.
        fp32_tensor = cvt_e2m1_to_fp32(tensor)
    else:
        fp32_tensor = tensor.to(torch.float32)

    # uint8 input packs two logical values per byte.
    logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1)
    axis_shape = fp32_tensor.size(-1)
    # Pad the quantized axis up to a multiple of 32 so it splits into full groups.
    padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
    pad_size = padded_axis_shape - axis_shape
    padded_tensor = F.pad(fp32_tensor, (0, pad_size))

    # Group the last dimension into blocks of 32 so each block broadcasts
    # against its single scale value.
    new_axis_shape = padded_tensor.shape[-1]
    new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
    padded_tensor = padded_tensor.view(*new_shape)
    dq_scale_padded = dq_scale.unsqueeze(-1)  # shape: [..., ceil(axis_shape/32), 1]
    out_padded = padded_tensor * dq_scale_padded

    # Flatten back and remove the padded tail
    out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
    out_tensor = out_padded[..., :axis_shape]

    # Cast to the requested dtype and restore the original dimension order.
    out_tensor = out_tensor.to(target_dtype).contiguous()
    out_tensor = out_tensor.transpose(axis, tensor.ndim - 1)

    return out_tensor
|
||||
|
||||
|
||||
dequantize_mxfp8_fn = _dequantize_mxfp8_fn
|
||||
@@ -261,6 +261,13 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
|
||||
return s
|
||||
|
||||
|
||||
def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
    """Reorder a bias tensor into the Marlin kernel's single-scale layout."""
    original_shape = s.shape
    _, perm_single = get_scale_perms()
    # Apply the per-row column permutation, then restore the original shape.
    permuted = s.reshape((-1, len(perm_single)))[:, perm_single]
    return permuted.reshape(*original_shape).contiguous()
|
||||
|
||||
|
||||
def marlin_moe_permute_scales(
|
||||
s: torch.Tensor,
|
||||
size_k: int,
|
||||
@@ -410,6 +417,7 @@ def apply_gptq_marlin_linear(
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
bias,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
@@ -425,9 +433,6 @@ def apply_gptq_marlin_linear(
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
@@ -456,6 +461,7 @@ def apply_awq_marlin_linear(
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
bias,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
@@ -470,7 +476,4 @@ def apply_awq_marlin_linear(
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
@@ -8,8 +8,8 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
|
||||
should_use_atomic_add_reduce)
|
||||
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias,
|
||||
marlin_permute_scales, should_use_atomic_add_reduce)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
@@ -22,7 +22,7 @@ def is_fp4_marlin_supported():
|
||||
return current_platform.has_device_capability(80)
|
||||
|
||||
|
||||
def fp4_marlin_process_scales(marlin_scales):
|
||||
def nvfp4_marlin_process_scales(marlin_scales):
|
||||
if not (marlin_scales >= 0).all():
|
||||
logger.warning_once(
|
||||
"NVFP4 Marlin assumes the scales to be >=0, but has encountered "
|
||||
@@ -56,7 +56,20 @@ def fp4_marlin_process_scales(marlin_scales):
|
||||
return marlin_scales
|
||||
|
||||
|
||||
def fp4_marlin_process_global_scale(global_scale):
|
||||
def mxfp4_marlin_process_scales(marlin_scales):
|
||||
# 8 is the number of scale number using by one thread
|
||||
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
|
||||
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
|
||||
marlin_scales.size(0) * 2, -1)
|
||||
|
||||
# fit the layout of fp8 dequantization
|
||||
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
|
||||
marlin_scales.size(0), -1)
|
||||
marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
|
||||
return marlin_scales
|
||||
|
||||
|
||||
def nvfp4_marlin_process_global_scale(global_scale):
|
||||
assert global_scale.dtype in [torch.half, torch.bfloat16]
|
||||
fp4_exponent = 2
|
||||
if global_scale.dtype == torch.half:
|
||||
@@ -73,7 +86,7 @@ def apply_fp4_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
weight_scale_2: torch.Tensor,
|
||||
weight_scale_2: Optional[torch.Tensor],
|
||||
workspace: torch.Tensor,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
@@ -94,6 +107,7 @@ def apply_fp4_marlin_linear(
|
||||
output = ops.gptq_marlin_gemm(a=reshaped_x,
|
||||
c=None,
|
||||
b_q_weight=weight,
|
||||
b_bias=bias,
|
||||
b_scales=weight_scale,
|
||||
global_scale=weight_scale_2,
|
||||
b_zeros=None,
|
||||
@@ -107,9 +121,6 @@ def apply_fp4_marlin_linear(
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
@@ -120,6 +131,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
"be used leveraging the Marlin kernel. This may degrade "
|
||||
"performance for compute-heavy workloads.")
|
||||
|
||||
is_nvfp4 = hasattr(layer, "weight_scale_2")
|
||||
group_size = 16 if is_nvfp4 else 32
|
||||
|
||||
part_size_n = layer.output_size_per_partition
|
||||
part_size_k = layer.input_size_per_partition
|
||||
param_dtype = layer.params_dtype
|
||||
@@ -145,18 +159,35 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Permute scales
|
||||
weight_scale = layer.weight_scale.T.to(param_dtype)
|
||||
weight_scale = layer.weight_scale.T.contiguous()
|
||||
|
||||
if not is_nvfp4:
|
||||
weight_scale = weight_scale.view(torch.float8_e8m0fnu)
|
||||
|
||||
weight_scale = weight_scale.to(param_dtype)
|
||||
weight_scale = marlin_permute_scales(s=weight_scale,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
group_size=16)
|
||||
weight_scale = fp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
||||
group_size=group_size)
|
||||
|
||||
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
|
||||
weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
|
||||
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
|
||||
requires_grad=False)
|
||||
if is_nvfp4:
|
||||
weight_scale = nvfp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
|
||||
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
|
||||
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
|
||||
requires_grad=False)
|
||||
else:
|
||||
weight_scale = mxfp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
assert layer.bias.shape == (part_size_n, )
|
||||
bias = marlin_permute_bias(layer.bias)
|
||||
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
|
||||
|
||||
return
|
||||
|
||||
@@ -168,6 +199,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
"be used leveraging the Marlin kernel. This may degrade "
|
||||
"performance for compute-heavy workloads.")
|
||||
|
||||
is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
|
||||
group_size = 16 if is_nvfp4 else 32
|
||||
|
||||
e = layer.num_experts
|
||||
k = layer.hidden_size
|
||||
n = layer.intermediate_size_per_partition
|
||||
@@ -208,8 +242,13 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
# WEIGHT SCALES
|
||||
# Permute scales
|
||||
for name in ["w13", "w2"]:
|
||||
scales = getattr(layer, name + "_weight_scale").to(param_dtype)
|
||||
global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
|
||||
scales = getattr(layer, name + "_weight_scale")
|
||||
if not is_nvfp4:
|
||||
scales = scales.view(torch.float8_e8m0fnu)
|
||||
scales = scales.to(param_dtype)
|
||||
if is_nvfp4:
|
||||
global_scale = getattr(layer,
|
||||
name + "_weight_scale_2").to(param_dtype)
|
||||
|
||||
tensor_list = []
|
||||
if "w13" in name:
|
||||
@@ -218,23 +257,47 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
size_n, size_k = k, n
|
||||
|
||||
for i in range(e):
|
||||
marlin_scales = marlin_permute_scales(s=scales[i].T,
|
||||
scale = scales[i].T
|
||||
|
||||
marlin_scales = marlin_permute_scales(s=scale,
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=16)
|
||||
marlin_scales = fp4_marlin_process_scales(marlin_scales)
|
||||
group_size=group_size)
|
||||
if is_nvfp4:
|
||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
||||
else:
|
||||
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
|
||||
tensor_list.append(marlin_scales)
|
||||
|
||||
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||
scales = torch.nn.Parameter(scales, requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale", scales)
|
||||
|
||||
global_scale = fp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale_2", global_scale)
|
||||
if is_nvfp4:
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = torch.nn.Parameter(global_scale,
|
||||
requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale_2", global_scale)
|
||||
|
||||
# BIAS
|
||||
# Permute bias
|
||||
for name in ["w13_bias", "w2_bias"]:
|
||||
if not hasattr(layer, name):
|
||||
continue
|
||||
bias = getattr(layer, name).to(param_dtype)
|
||||
|
||||
tensor_list = []
|
||||
for i in range(e):
|
||||
expert_bias = bias[i]
|
||||
|
||||
tensor_list.append(marlin_permute_bias(expert_bias))
|
||||
|
||||
bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||
bias = torch.nn.Parameter(bias, requires_grad=False)
|
||||
setattr(layer, name, bias)
|
||||
|
||||
|
||||
def rand_marlin_weight_fp4_like(weight, group_size):
|
||||
def rand_marlin_weight_nvfp4_like(weight, group_size):
|
||||
assert group_size > 0
|
||||
size_n, size_k = weight.shape
|
||||
device = weight.device
|
||||
@@ -276,8 +339,58 @@ def rand_marlin_weight_fp4_like(weight, group_size):
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=group_size)
|
||||
marlin_scales = fp4_marlin_process_scales(marlin_scales)
|
||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
||||
|
||||
global_scale = fp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||
|
||||
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
|
||||
|
||||
|
||||
def rand_marlin_weight_mxfp4_like(weight, group_size):
|
||||
assert group_size > 0
|
||||
size_n, size_k = weight.shape
|
||||
device = weight.device
|
||||
|
||||
scales = torch.randint(100,
|
||||
125, (size_n, size_k // group_size),
|
||||
dtype=torch.uint8,
|
||||
device=weight.device)
|
||||
scales = scales.view(torch.float8_e8m0fnu)
|
||||
|
||||
fp4_weight = torch.randint(0,
|
||||
256, (size_n, size_k // 2),
|
||||
dtype=torch.uint8,
|
||||
device=weight.device)
|
||||
fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
|
||||
((fp4_weight & 0b01110000) >> 2))
|
||||
fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
|
||||
fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
|
||||
|
||||
fp4_weight2 = fp4_weight << 4
|
||||
fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
|
||||
((fp4_weight2 & 0b01110000) >> 2))
|
||||
fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
|
||||
fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
|
||||
|
||||
weight_ref = torch.cat(
|
||||
[fp4_weight_part_2.unsqueeze(2),
|
||||
fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
|
||||
weight_ref = weight_ref * \
|
||||
scales.repeat_interleave(group_size, 1).to(weight.dtype)
|
||||
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
|
||||
perm=torch.empty(0, dtype=torch.int, device=device),
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
num_bits=4,
|
||||
)
|
||||
|
||||
marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=group_size)
|
||||
|
||||
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
|
||||
|
||||
return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)
|
||||
|
||||
@@ -1,45 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
OCP_MX_BLOCK_SIZE = 32
|
||||
|
||||
|
||||
def per_token_group_quant_mxfp4(x: torch.Tensor,
|
||||
block_k: int,
|
||||
scale_calculation_mode: str = "even"
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||
""" weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
|
||||
"""
|
||||
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.tensor_details.layout import StridedLayout
|
||||
if (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)
|
||||
and not is_torch_equal_or_newer("2.8.1")):
|
||||
logger.warning_once(
|
||||
"Mxfp4 on hopper is running on torch < 2.8.1, "
|
||||
"this cause swizling to be disabled, which may "
|
||||
"cause performance degradation. Please upgrade to torch nightly")
|
||||
value_layout, value_layout_opts = StridedLayout, dict()
|
||||
scale_layout, scale_layout_opts = StridedLayout, dict()
|
||||
else:
|
||||
value_layout, value_layout_opts = \
|
||||
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
scale_layout, scale_layout_opts = (
|
||||
layout.make_default_matmul_mxfp4_w_scale_layout(
|
||||
mx_axis=1, num_warps=num_warps))
|
||||
if current_platform.is_cuda() and \
|
||||
current_platform.is_device_capability(100):
|
||||
constraints = {
|
||||
"is_persistent": True,
|
||||
"epilogue_subtile": 1,
|
||||
}
|
||||
opt_flags.update_opt_flags_constraints(constraints)
|
||||
# transpose the tensor so that the quantization axis is on dim1
|
||||
quant_tensor = quant_tensor.transpose(-2, -1)
|
||||
scale = scale.transpose(-2, -1)
|
||||
quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
|
||||
value_layout, **value_layout_opts)
|
||||
scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
|
||||
**scale_layout_opts)
|
||||
return quant_tensor, InFlexData(), scale
|
||||
|
||||
|
||||
def _can_support_mxfp4(use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
scoring_func: str = "softmax",
|
||||
activation: str = "swiglu_oai",
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None):
|
||||
return not (use_grouped_topk or topk_group or num_expert_group
|
||||
or expert_map or custom_routing_function
|
||||
or e_score_correction_bias or apply_router_weight_on_input
|
||||
or scoring_func != "softmax" or activation != "swiglu_oai"
|
||||
or expert_load_view or logical_to_physical_map
|
||||
or logical_replica_count)
|
||||
|
||||
|
||||
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
|
||||
float_dtype: torch.dtype) -> torch.Tensor:
|
||||
try:
|
||||
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
|
||||
fake_quantize_fp4_fp6_per_group_with_scale)
|
||||
from quark.torch.quantization.utils import (even_round,
|
||||
reshape_to_blocks)
|
||||
from quark.torch.kernel import mx
|
||||
except ImportError as err:
|
||||
raise ImportError("The package `amd-quark` is required to use "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
axis = -1
|
||||
block_x = reshape_to_blocks(x, block_k, axis)
|
||||
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
|
||||
amax = amax.squeeze(-1)
|
||||
return mx.dq_mxfp4(x, scale, float_dtype)
|
||||
|
||||
# TODO: there are other rounding strategies supported in quark and in the
|
||||
# config.json that we do not check for here!
|
||||
if scale_calculation_mode != "even":
|
||||
raise NotImplementedError(
|
||||
f"Scale calculation mode {scale_calculation_mode} is not yet "
|
||||
"supported in MX-FP4 quantization")
|
||||
scale = even_round(amax, "fp4")
|
||||
|
||||
# Apply dequantize(quantize(x)).
|
||||
x = fake_quantize_fp4_fp6_per_group_with_scale(
|
||||
x,
|
||||
scale.to(x.device),
|
||||
axis=axis,
|
||||
group_size=block_k,
|
||||
quant_dtype="fp4",
|
||||
def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor,
|
||||
float_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch.empty((*x.shape[:-1], x.shape[-1] * 2),
|
||||
dtype=float_dtype,
|
||||
device=x.device)
|
||||
|
||||
|
||||
def _quant_dequant_mxfp4(x: torch.Tensor,
|
||||
scale_calculation_mode: str = "even") -> torch.Tensor:
|
||||
try:
|
||||
from quark.torch.kernel import mx
|
||||
except ImportError as err:
|
||||
raise ImportError("The package `amd-quark` is required to use "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
return mx.qdq_mxfp4(x, scale_calculation_mode)
|
||||
|
||||
|
||||
def _quant_dequant_mxfp4_fake(x: torch.Tensor,
|
||||
scale_calculation_mode: str = "even"
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="dequant_mxfp4",
|
||||
op_func=_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_dequant_mxfp4_fake,
|
||||
)
|
||||
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
return x, scale
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="quant_dequant_mxfp4",
|
||||
op_func=_quant_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_quant_dequant_mxfp4_fake,
|
||||
)
|
||||
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
@@ -3,22 +3,41 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from collections.abc import Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Optional
|
||||
from typing import ClassVar, NamedTuple, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||
from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||
class _GroupShape(NamedTuple):
|
||||
row: int
|
||||
col: int
|
||||
|
||||
|
||||
class GroupShape(_GroupShape):
|
||||
"""
|
||||
This class describes the quantization group shape.
|
||||
It includes static members for common shapes (per-tensor, per-token).
|
||||
"""
|
||||
|
||||
# Aliases for common quantization group shapes
|
||||
PER_TENSOR: ClassVar['GroupShape']
|
||||
PER_TOKEN: ClassVar['GroupShape']
|
||||
|
||||
|
||||
GroupShape.PER_TENSOR = GroupShape(-1, -1)
|
||||
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
|
||||
int]):
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
# -1 means full extent
|
||||
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
|
||||
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
|
||||
@@ -58,7 +77,7 @@ def group_broadcast(t, shape):
|
||||
# (i.e. per-token-per-group)
|
||||
def scaled_quantize(
|
||||
x: torch.Tensor,
|
||||
group_shape: tuple[int, int],
|
||||
group_shape: GroupShape,
|
||||
quant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
@@ -99,7 +118,7 @@ def scaled_quantize(
|
||||
def scaled_dequantize(
|
||||
x_q: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
group_shape: Optional[tuple[int, int]] = None,
|
||||
group_shape: Optional[GroupShape] = None,
|
||||
out_dtype: torch.dtype = torch.float32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if group_shape is not None:
|
||||
@@ -332,6 +351,10 @@ def quantize_weights(w: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
def gptq_quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
@@ -571,3 +594,56 @@ def awq_pack(
|
||||
q_w = q_w.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return pack_cols(q_w, num_bits, size_k, size_n)
|
||||
|
||||
|
||||
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Pad and block-interleave the FP4 block-scales so that they match the data
|
||||
layout expected by the CUTLASS / FlashInfer kernels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scale: torch.Tensor
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The swizzled tensor with the same logical shape as *scale*.
|
||||
"""
|
||||
assert scale.dtype == torch.float8_e4m3fn, (
|
||||
"swizzle_blockscale expects the input tensor to be in "
|
||||
"torch.float8_e4m3fn format.")
|
||||
|
||||
scale_ndim = scale.ndim
|
||||
if scale_ndim == 2:
|
||||
scale = scale.unsqueeze(0) # (1, M, K)
|
||||
assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
|
||||
|
||||
B, M, K = scale.shape
|
||||
|
||||
def _round_up(x: int, m: int) -> int:
|
||||
return (x + m - 1) // m * m
|
||||
|
||||
M_padded = _round_up(M, 128)
|
||||
K_padded = _round_up(K, 4)
|
||||
|
||||
padded = torch.zeros((B, M_padded, K_padded),
|
||||
dtype=scale.dtype,
|
||||
device=scale.device)
|
||||
padded[:B, :M, :K] = scale
|
||||
|
||||
# Reshape / permute to the layout required by the kernel.
|
||||
padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
|
||||
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
|
||||
|
||||
if scale_ndim == 2:
|
||||
return swizzled.reshape(M, K)
|
||||
return swizzled.reshape(B, M, K)
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
@@ -8,7 +8,6 @@ import torch.distributed as dist
|
||||
from torch import nn
|
||||
from transformers import GptOssConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
@@ -28,7 +27,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv
|
||||
|
||||
from .utils import extract_layer_index, maybe_prefix
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class OAIAttention(nn.Module):
|
||||
@@ -70,12 +70,9 @@ class OAIAttention(nn.Module):
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
|
||||
# else torch.bfloat16)
|
||||
attention_sink_dtype = torch.bfloat16
|
||||
self.sinks = torch.nn.Parameter(
|
||||
torch.empty(config.num_attention_heads // tp_size,
|
||||
dtype=attention_sink_dtype,
|
||||
dtype=torch.bfloat16,
|
||||
requires_grad=False))
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
@@ -207,6 +204,7 @@ class GptOssModel(nn.Module):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.config.hidden_size = self.config.hidden_size
|
||||
self.embedding = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
@@ -229,8 +227,364 @@ class GptOssModel(nn.Module):
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def _load_weights_mxfp4(
|
||||
self,
|
||||
ep_rank_end: int,
|
||||
ep_rank_start: int,
|
||||
heads_per_rank: int,
|
||||
head_start: int,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
stacked_params_mapping: list[tuple[str, ...]],
|
||||
) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
mxfp4_block = 32
|
||||
use_ep = self.parallel_config.enable_expert_parallel
|
||||
num_experts = self.config.num_local_experts
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
intermediate_size_block = intermediate_size // mxfp4_block
|
||||
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
|
||||
tp_size)
|
||||
per_rank_intermediate_size = (per_rank_intermediate_size_block *
|
||||
mxfp4_block)
|
||||
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
|
||||
intermediate_size)
|
||||
|
||||
for name, weight in weights:
|
||||
# FIXME(woosuk): Remove this after testing.
|
||||
weight = weight.cuda()
|
||||
|
||||
if ".w13_weight_scale" in name:
|
||||
# Handle MLP gate and up projection weights scale
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end,
|
||||
...]
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_weight_scale" in name:
|
||||
# Handle MLP down projection weights
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[..., tp_rank_start //
|
||||
mxfp4_block:tp_rank_end //
|
||||
mxfp4_block]
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w13_weight" in name:
|
||||
# Handle MLP gate and up projection weights
|
||||
# flat weight from (E, 2 * N, block_size, entry_per_block)
|
||||
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
|
||||
weight = weight.view(num_experts, 2 * intermediate_size,
|
||||
-1).contiguous()
|
||||
|
||||
# Extract gate and up projection parts
|
||||
# since the weight is shuffled, we can slice directly
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end,
|
||||
...]
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_weight" in name:
|
||||
# Handle MLP down projection weights
|
||||
# same flatten here, but since 2 mx4 value are packed in 1
|
||||
# uint8, divide by 2
|
||||
weight = weight.view(num_experts, -1,
|
||||
intermediate_size // 2).contiguous()
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[...,
|
||||
tp_rank_start // 2:tp_rank_end // 2]
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w13_bias" in name:
|
||||
# Handle MLP gate and up projection biases
|
||||
# Extract gate and up projection bias parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_bias" in name:
|
||||
# Handle MLP down projection bias
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
if use_ep:
|
||||
weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
# (only load on rank 0 to avoid duplication)
|
||||
if tp_rank != 0:
|
||||
weight.zero_()
|
||||
weight_loader(param,
|
||||
weight,
|
||||
weight_name=name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif "sinks" in name:
|
||||
# Handle attention sinks (distributed across ranks)
|
||||
param = params_dict[name]
|
||||
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
|
||||
param.data.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
if weight_loader == default_weight_loader:
|
||||
weight_loader(param, weight)
|
||||
else:
|
||||
weight_loader(param, weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Handle all other weights with potential renaming
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _load_weights_other(
|
||||
self,
|
||||
ep_rank_start: int,
|
||||
ep_rank_end: int,
|
||||
heads_per_rank: int,
|
||||
head_start: int,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
stacked_params_mapping: list[tuple[str, ...]],
|
||||
) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
use_ep = self.parallel_config.enable_expert_parallel
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
|
||||
intermediate_size)
|
||||
|
||||
for name, weight in weights:
|
||||
if ".w13_weight" in name:
|
||||
# Handle MLP gate and up projection weights
|
||||
# Extract gate and up projection parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, :,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
|
||||
param = params_dict[name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_weight" in name:
|
||||
# Handle MLP down projection weights
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
|
||||
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
|
||||
param = params_dict[name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w13_bias" in name:
|
||||
# Handle MLP gate and up projection biases
|
||||
# Extract gate and up projection bias parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
param = params_dict[name]
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif ".w2_bias" in name:
|
||||
# Handle MLP down projection bias
|
||||
if use_ep:
|
||||
weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
# (only load on rank 0 to avoid duplication)
|
||||
if tp_rank != 0:
|
||||
weight.zero_()
|
||||
param = params_dict[name]
|
||||
param.copy_(weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
elif "sinks" in name:
|
||||
# Handle attention sinks (distributed across ranks)
|
||||
param = params_dict[name]
|
||||
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
|
||||
param.data.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
if weight_loader == default_weight_loader:
|
||||
weight_loader(param, weight)
|
||||
else:
|
||||
weight_loader(param, weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Handle all other weights with potential renaming
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv", ".q_proj", "q"),
|
||||
(".qkv", ".k_proj", "k"),
|
||||
(".qkv", ".v_proj", "v"),
|
||||
]
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# Attention heads per rank
|
||||
heads_per_rank = self.config.num_attention_heads // tp_size
|
||||
head_start = tp_rank * heads_per_rank
|
||||
|
||||
ep_size = get_ep_group().world_size
|
||||
ep_rank = get_ep_group().rank
|
||||
num_experts = self.config.num_local_experts
|
||||
experts_per_rank = num_experts // ep_size
|
||||
ep_rank_start = ep_rank * experts_per_rank
|
||||
ep_rank_end = (ep_rank + 1) * experts_per_rank
|
||||
|
||||
quant_method = (self.config.quantization_config['quant_method'] if
|
||||
hasattr(self.config, "quantization_config") else None)
|
||||
if quant_method == "mxfp4":
|
||||
return self._load_weights_mxfp4(ep_rank_end, ep_rank_start,
|
||||
heads_per_rank, head_start,
|
||||
weights, stacked_params_mapping)
|
||||
else:
|
||||
return self._load_weights_other(ep_rank_end, ep_rank_start,
|
||||
heads_per_rank, head_start,
|
||||
weights, stacked_params_mapping)
|
||||
|
||||
|
||||
class GptOssForCausalLM(nn.Module):
|
||||
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
".self_attn.": ".attn.",
|
||||
".post_attention_layernorm.": ".mlp.norm.",
|
||||
},
|
||||
orig_to_new_suffix={
|
||||
".embed_tokens.weight": ".embedding.weight",
|
||||
".input_layernorm.weight": ".attn.norm.weight",
|
||||
".post_attention_layernorm.weight": ".mlp.norm.weight",
|
||||
|
||||
# MoE MXFP4 weights
|
||||
".gate_up_proj_blocks": ".w13_weight",
|
||||
".down_proj_blocks": ".w2_weight",
|
||||
".gate_up_proj_scales": ".w13_weight_scale",
|
||||
".down_proj_scales": ".w2_weight_scale",
|
||||
|
||||
# MoE other weights
|
||||
".gate_up_proj": ".w13_weight",
|
||||
".down_proj": ".w2_weight",
|
||||
|
||||
# MoE Bias
|
||||
".gate_up_proj_bias": ".w13_bias",
|
||||
".down_proj_bias": ".w2_bias",
|
||||
},
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -239,16 +593,17 @@ class GptOssForCausalLM(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config.hf_config
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
|
||||
self.model = GptOssModel(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.model_config.vocab_size,
|
||||
self.model_config.hidden_size,
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -265,354 +620,11 @@ class GptOssForCausalLM(nn.Module):
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def _load_weights_mxfp4(
|
||||
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
rename_mapping = {
|
||||
"self_attn": "attn",
|
||||
"input_layernorm.weight": "attn.norm.weight",
|
||||
"post_attention_layernorm.weight": "mlp.norm.weight",
|
||||
"embed_tokens": "embedding",
|
||||
}
|
||||
|
||||
def maybe_rename(name: str) -> str:
|
||||
for remap_name, new_name in rename_mapping.items():
|
||||
if remap_name in name:
|
||||
return name.replace(remap_name, new_name)
|
||||
return name
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
mxfp4_block = 32
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size = self.model_config.intermediate_size
|
||||
intermediate_size_block = intermediate_size // mxfp4_block
|
||||
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
|
||||
tp_size)
|
||||
per_rank_intermediate_size = (per_rank_intermediate_size_block *
|
||||
mxfp4_block)
|
||||
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
|
||||
intermediate_size)
|
||||
|
||||
# Attention heads per rank
|
||||
heads_per_rank = self.model_config.num_attention_heads // tp_size
|
||||
head_start = tp_rank * heads_per_rank
|
||||
|
||||
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
|
||||
ep_size = get_ep_group().world_size
|
||||
ep_rank = get_ep_group().rank
|
||||
num_experts = self.model_config.num_local_experts
|
||||
experts_per_rank = num_experts // ep_size
|
||||
ep_rank_start = ep_rank * experts_per_rank
|
||||
ep_rank_end = (ep_rank + 1) * experts_per_rank
|
||||
|
||||
for name, weight in weights:
|
||||
# FIXME(woosuk): Remove this after testing.
|
||||
weight = weight.cuda()
|
||||
|
||||
if "gate_up_proj_blocks" in name:
|
||||
# Handle MLP gate and up projection weights
|
||||
new_name = name.replace("gate_up_proj_blocks", "w13_weight")
|
||||
|
||||
# flat weight from (E, 2 * N, block_size, entry_per_block)
|
||||
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
|
||||
weight = weight.view(num_experts, 2 * intermediate_size,
|
||||
-1).contiguous()
|
||||
|
||||
# Extract gate and up projection parts
|
||||
# since the weight is shuffled, we can slice directly
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end,
|
||||
...]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "down_proj_blocks" in name:
|
||||
# Handle MLP down projection weights
|
||||
new_name = name.replace("down_proj_blocks", "w2_weight")
|
||||
# same flatten here, but since 2 mx4 value are packed in 1
|
||||
# uint8, divide by 2
|
||||
weight = weight.view(num_experts, -1,
|
||||
intermediate_size // 2).contiguous()
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[...,
|
||||
tp_rank_start // 2:tp_rank_end // 2]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "gate_up_proj_scales" in name:
|
||||
# Handle MLP gate and up projection weights scale
|
||||
new_name = name.replace("gate_up_proj_scales",
|
||||
"w13_weight_scale")
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end,
|
||||
...]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "down_proj_scales" in name:
|
||||
# Handle MLP down projection weights
|
||||
new_name = name.replace("down_proj_scales", "w2_weight_scale")
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[..., tp_rank_start //
|
||||
mxfp4_block:tp_rank_end //
|
||||
mxfp4_block]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
elif "gate_up_proj_bias" in name:
|
||||
# Handle MLP gate and up projection biases
|
||||
new_name = name.replace("gate_up_proj_bias", "w13_bias")
|
||||
|
||||
# Extract gate and up projection bias parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "down_proj_bias" in name:
|
||||
# Handle MLP down projection bias
|
||||
new_name = name.replace("down_proj_bias", "w2_bias")
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
if use_ep:
|
||||
weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
# (only load on rank 0 to avoid duplication)
|
||||
if tp_rank != 0:
|
||||
weight.zero_()
|
||||
weight_loader(param,
|
||||
weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None)
|
||||
loaded_params.add(new_name)
|
||||
elif "sinks" in name:
|
||||
# Handle attention sinks (distributed across ranks)
|
||||
name = name.replace("self_attn", "attn")
|
||||
param = params_dict[name]
|
||||
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
|
||||
param.data.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
|
||||
shard_id = ("q" if "q_proj" in name else
|
||||
"k" if "k_proj" in name else "v")
|
||||
name = name.replace("self_attn", "attn")
|
||||
param_name = name.replace(f"{shard_id}_proj", "qkv")
|
||||
param = params_dict[param_name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, weight, loaded_shard_id=shard_id)
|
||||
loaded_params.add(param_name)
|
||||
else:
|
||||
# Handle all other weights with potential renaming
|
||||
renamed_name = maybe_rename(name)
|
||||
if renamed_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[renamed_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight)
|
||||
loaded_params.add(renamed_name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
def _load_weights_other(
|
||||
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
rename_mapping = {
|
||||
"self_attn": "attn",
|
||||
"input_layernorm.weight": "attn.norm.weight",
|
||||
"post_attention_layernorm.weight": "mlp.norm.weight",
|
||||
"embed_tokens": "embedding",
|
||||
}
|
||||
|
||||
def maybe_rename(name: str) -> str:
|
||||
for remap_name, new_name in rename_mapping.items():
|
||||
if remap_name in name:
|
||||
return name.replace(remap_name, new_name)
|
||||
return name
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size = self.model_config.intermediate_size
|
||||
|
||||
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
|
||||
intermediate_size)
|
||||
|
||||
# Attention heads per rank
|
||||
heads_per_rank = self.model_config.num_attention_heads // tp_size
|
||||
head_start = tp_rank * heads_per_rank
|
||||
|
||||
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
|
||||
ep_size = get_ep_group().world_size
|
||||
ep_rank = get_ep_group().rank
|
||||
num_experts = self.model_config.num_local_experts
|
||||
experts_per_rank = num_experts // ep_size
|
||||
ep_rank_start = ep_rank * experts_per_rank
|
||||
ep_rank_end = (ep_rank + 1) * experts_per_rank
|
||||
|
||||
for name, weight in weights:
|
||||
if ".experts.gate_up_proj" in name and "bias" not in name:
|
||||
# Handle MLP gate and up projection weights
|
||||
new_name = name.replace(".experts.gate_up_proj",
|
||||
".experts.w13_weight")
|
||||
|
||||
# Extract gate and up projection parts
|
||||
# since the weight is shuffled, we can slice directly
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, :,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
|
||||
param = params_dict[new_name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif ".experts.down_proj" in name and "bias" not in name:
|
||||
# Handle MLP down projection weights
|
||||
new_name = name.replace(".experts.down_proj",
|
||||
".experts.w2_weight")
|
||||
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
|
||||
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
|
||||
param = params_dict[new_name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "gate_up_proj_bias" in name:
|
||||
# Handle MLP gate and up projection biases
|
||||
new_name = name.replace("gate_up_proj_bias", "w13_bias")
|
||||
|
||||
# Extract gate and up projection bias parts
|
||||
if use_ep:
|
||||
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
narrow_weight = weight[:,
|
||||
2 * tp_rank_start:2 * tp_rank_end]
|
||||
|
||||
param = params_dict[new_name]
|
||||
|
||||
param.copy_(narrow_weight)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "down_proj_bias" in name:
|
||||
# Handle MLP down projection bias
|
||||
new_name = name.replace("down_proj_bias", "w2_bias")
|
||||
|
||||
if use_ep:
|
||||
weight = weight[ep_rank_start:ep_rank_end, ...]
|
||||
else:
|
||||
# (only load on rank 0 to avoid duplication)
|
||||
if tp_rank != 0:
|
||||
weight.zero_()
|
||||
param = params_dict[new_name]
|
||||
param.copy_(weight)
|
||||
loaded_params.add(new_name)
|
||||
elif "sinks" in name:
|
||||
# Handle attention sinks (distributed across ranks)
|
||||
name = name.replace("self_attn", "attn")
|
||||
param = params_dict[name]
|
||||
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
|
||||
param.data.copy_(narrow_weight)
|
||||
loaded_params.add(name)
|
||||
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
|
||||
shard_id = ("q" if "q_proj" in name else
|
||||
"k" if "k_proj" in name else "v")
|
||||
name = name.replace("self_attn", "attn")
|
||||
param_name = name.replace(f"{shard_id}_proj", "qkv")
|
||||
param = params_dict[param_name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, weight, loaded_shard_id=shard_id)
|
||||
loaded_params.add(param_name)
|
||||
else:
|
||||
# Handle all other weights with potential renaming
|
||||
|
||||
renamed_name = maybe_rename(name)
|
||||
if renamed_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[renamed_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight)
|
||||
loaded_params.add(renamed_name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
quant_method = (self.model_config.quantization_config['quant_method']
|
||||
if hasattr(self.model_config, "quantization_config")
|
||||
else None)
|
||||
if quant_method == "mxfp4":
|
||||
return self._load_weights_mxfp4(weights)
|
||||
else:
|
||||
return self._load_weights_other(weights)
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
@@ -179,12 +179,14 @@ class CudaPlatformBase(Platform):
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1,
|
||||
use_mla) -> str:
|
||||
kv_cache_dtype, block_size, use_v1, use_mla,
|
||||
has_sink) -> str:
|
||||
if use_mla:
|
||||
# TODO(lucas): refactor to be more concise
|
||||
# TODO(lucas): refactor to be more concise
|
||||
# we should probably consider factoring out V1 here
|
||||
if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
|
||||
if selected_backend == _Backend.CUTLASS_MLA or (
|
||||
cls.is_device_capability(100) and selected_backend is None
|
||||
and block_size == 128):
|
||||
if use_v1:
|
||||
logger.info_once("Using Cutlass MLA backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends.mla."
|
||||
@@ -223,37 +225,101 @@ class CudaPlatformBase(Platform):
|
||||
return ("vllm.attention.backends."
|
||||
"flashmla.FlashMLABackend")
|
||||
if use_v1:
|
||||
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
|
||||
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
|
||||
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
|
||||
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
|
||||
|
||||
if selected_backend == _Backend.FLASHINFER:
|
||||
logger.info_once("Using FlashInfer backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
|
||||
if selected_backend == _Backend.FLEX_ATTENTION:
|
||||
logger.info("Using FlexAttenion backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
|
||||
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
|
||||
if cls.has_device_capability(100):
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
set_kv_cache_layout)
|
||||
set_kv_cache_layout("HND")
|
||||
return FLASHINFER_V1
|
||||
elif selected_backend == _Backend.FLEX_ATTENTION:
|
||||
logger.info_once("Using FlexAttention backend on V1 engine.")
|
||||
return FLEX_ATTENTION_V1
|
||||
elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
|
||||
logger.info_once("Using Triton backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends."
|
||||
"triton_attn.TritonAttentionBackend")
|
||||
return TRITON_ATTN_VLLM_V1
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
logger.info_once("Using Flash Attention backend on V1 engine.")
|
||||
return FLASH_ATTN_V1
|
||||
# elif selected_backend == _Backend.TREE_ATTN:
|
||||
# logger.info_once("Using Tree Attention backend on V1 engine.")
|
||||
# return TREE_ATTN_V1
|
||||
# elif selected_backend == _Backend.XFORMERS_VLLM_V1:
|
||||
# logger.info_once("Using XFormers backend on V1 engine.")
|
||||
# return XFORMERS_V1
|
||||
|
||||
from vllm.attention.selector import is_attn_backend_supported
|
||||
|
||||
# Default backends for V1 engine
|
||||
# Prefer FlashInfer for Blackwell GPUs if installed
|
||||
if cls.is_device_capability(100):
|
||||
# Prefer FlashInfer for V1 on Blackwell GPUs if installed
|
||||
try:
|
||||
import flashinfer # noqa: F401
|
||||
logger.info_once(
|
||||
"Using FlashInfer backend on V1 engine by default for "
|
||||
"Blackwell (SM 10.0) GPUs.")
|
||||
return ("vllm.v1.attention.backends."
|
||||
"flashinfer.FlashInferBackend")
|
||||
except ImportError:
|
||||
if is_default_backend_supported := is_attn_backend_supported(
|
||||
FLASHINFER_V1, head_size, dtype):
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
set_kv_cache_layout)
|
||||
|
||||
logger.info_once(
|
||||
"Using FlashInfer backend with HND KV cache layout on "
|
||||
"V1 engine by default for Blackwell (SM 10.0) GPUs.")
|
||||
set_kv_cache_layout("HND")
|
||||
|
||||
return FLASHINFER_V1
|
||||
|
||||
if not is_default_backend_supported.can_import:
|
||||
logger.warning_once(
|
||||
"FlashInfer failed to import for V1 engine on "
|
||||
"Blackwell (SM 10.0) GPUs; it is recommended to "
|
||||
"install FlashInfer for better performance.")
|
||||
pass
|
||||
|
||||
# FlashAttention is the default for SM 8.0+ GPUs
|
||||
if cls.has_device_capability(80):
|
||||
logger.info_once("Using Flash Attention backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends."
|
||||
"flash_attn.FlashAttentionBackend")
|
||||
if has_sink and not cls.is_device_capability(90):
|
||||
logger.info_once("Using Triton backend on V1 engine.")
|
||||
return TRITON_ATTN_VLLM_V1
|
||||
if is_default_backend_supported := is_attn_backend_supported(
|
||||
FLASH_ATTN_V1, head_size, dtype,
|
||||
allow_import_error=False):
|
||||
logger.info_once("Using Flash Attention backend on "
|
||||
"V1 engine.")
|
||||
return FLASH_ATTN_V1
|
||||
|
||||
# FlexAttention is the default for older GPUs
|
||||
else:
|
||||
logger.info_once("Using FlexAttention backend on V1 engine.")
|
||||
return FLEX_ATTENTION_V1
|
||||
|
||||
assert not is_default_backend_supported
|
||||
|
||||
use_flex_attention_reason = {}
|
||||
if not is_default_backend_supported.head_size:
|
||||
use_flex_attention_reason["head_size"] = head_size
|
||||
if not is_default_backend_supported.dtype:
|
||||
use_flex_attention_reason["dtype"] = dtype
|
||||
|
||||
logger.info_once(
|
||||
"Using FlexAttention backend for %s on V1 engine.",
|
||||
", ".join(f"{k}={v}"
|
||||
for k, v in use_flex_attention_reason.items()),
|
||||
)
|
||||
return FLEX_ATTENTION_V1
|
||||
|
||||
# Backends for V0 engine
|
||||
if selected_backend == _Backend.FLASHINFER:
|
||||
logger.info("Using FlashInfer backend.")
|
||||
if cls.has_device_capability(100):
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
set_kv_cache_layout)
|
||||
logger.info_once(
|
||||
"Using HND KV cache layout on V1 engine by default for "
|
||||
"Blackwell (SM 10.0) GPUs.")
|
||||
set_kv_cache_layout("HND")
|
||||
return "vllm.attention.backends.flashinfer.FlashInferBackend"
|
||||
elif selected_backend == _Backend.XFORMERS:
|
||||
logger.info("Using XFormers backend.")
|
||||
@@ -262,6 +328,10 @@ class CudaPlatformBase(Platform):
|
||||
logger.info("Using DualChunkFlashAttention backend.")
|
||||
return ("vllm.attention.backends.dual_chunk_flash_attn."
|
||||
"DualChunkFlashAttentionBackend")
|
||||
elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
|
||||
logger.info("Using DifferentialFlashAttention backend.")
|
||||
return ("vllm.attention.backends.differential_flash_attn."
|
||||
"DifferentialFlashAttentionBackend")
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
pass
|
||||
elif selected_backend:
|
||||
@@ -291,7 +361,7 @@ class CudaPlatformBase(Platform):
|
||||
# installed.
|
||||
if target_backend == _Backend.FLASH_ATTN:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
import vllm.vllm_flash_attn # noqa: F401
|
||||
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||
FlashAttentionBackend, flash_attn_supports_fp8)
|
||||
|
||||
|
||||
@@ -2908,3 +2908,37 @@ def is_torch_equal_or_newer(target: str) -> bool:
|
||||
except Exception:
|
||||
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
|
||||
return Version(importlib.metadata.version('torch')) >= Version(target)
|
||||
|
||||
|
||||
@cache
|
||||
def _has_module(module_name: str) -> bool:
|
||||
"""Return True if *module_name* can be found in the current environment.
|
||||
|
||||
The result is cached so that subsequent queries for the same module incur
|
||||
no additional overhead.
|
||||
"""
|
||||
return importlib.util.find_spec(module_name) is not None
|
||||
|
||||
|
||||
def has_pplx() -> bool:
|
||||
"""Whether the optional `pplx_kernels` package is available."""
|
||||
|
||||
return _has_module("pplx_kernels")
|
||||
|
||||
|
||||
def has_deep_ep() -> bool:
|
||||
"""Whether the optional `deep_ep` package is available."""
|
||||
|
||||
return _has_module("deep_ep")
|
||||
|
||||
|
||||
def has_deep_gemm() -> bool:
|
||||
"""Whether the optional `deep_gemm` package is available."""
|
||||
|
||||
return _has_module("deep_gemm")
|
||||
|
||||
|
||||
def has_triton_kernels() -> bool:
|
||||
"""Whether the optional `triton_kernels` package is available."""
|
||||
|
||||
return _has_module("triton_kernels")
|
||||
|
||||
@@ -44,10 +44,25 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> list[int]:
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes.")
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN_VLLM_V1"
|
||||
@@ -657,15 +672,12 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@@ -675,6 +687,8 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
self.sliding_window = (sliding_window - 1, sliding_window - 1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
@@ -684,28 +698,33 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||
if head_size not in support_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by FlashAttention. "
|
||||
f"Supported head sizes are: {support_head_sizes}. "
|
||||
"Set VLLM_USE_V1=0 to use another attention backend.")
|
||||
FlashAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
if attn_type not in [
|
||||
AttentionType.DECODER, AttentionType.ENCODER_ONLY
|
||||
]:
|
||||
raise NotImplementedError("Encoder/decoder cross-attention "
|
||||
"is not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
|
||||
self.use_irope = use_irope
|
||||
self.attn_type = attn_type
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) \
|
||||
and not flash_attn_supports_fp8():
|
||||
raise NotImplementedError(
|
||||
"FlashAttention does not support fp8 kv-cache on this device.")
|
||||
|
||||
self.sinks = sinks
|
||||
if self.sinks is not None:
|
||||
assert self.vllm_flash_attn_version == 3, (
|
||||
"Sinks are only supported in FlashAttention 3")
|
||||
assert self.sinks.shape[0] == num_heads, (
|
||||
"Sinks must have the same number of heads as the number of "
|
||||
"heads in the layer")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -715,6 +734,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
@@ -732,6 +752,11 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for FlashAttentionImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
@@ -37,6 +37,10 @@ class FlashInferBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> list[int]:
|
||||
return [64, 128, 256]
|
||||
@@ -510,10 +514,10 @@ class FlashInferImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if use_irope:
|
||||
@@ -543,6 +547,19 @@ class FlashInferImpl(AttentionImpl):
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferImpl")
|
||||
|
||||
self.sinks: Optional[torch.Tensor] = None
|
||||
if sinks is not None:
|
||||
if sinks.shape[0] != num_heads:
|
||||
raise ValueError(
|
||||
"Sinks must have the same number of heads as the number of "
|
||||
f"heads in the layer. Expected {num_heads}, but got "
|
||||
f"{sinks.shape[0]}."
|
||||
)
|
||||
# Cast sinks to float32 if needed (FlashInfer requirement)
|
||||
if sinks.dtype != torch.float32:
|
||||
sinks = sinks.to(torch.float32)
|
||||
self.sinks = sinks
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -553,6 +570,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashInferMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashInfer.
|
||||
|
||||
@@ -567,6 +585,11 @@ class FlashInferImpl(AttentionImpl):
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for FlashInferImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
@@ -44,6 +44,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
|
||||
class FlexAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> list[int]:
|
||||
return [16, 32, 64, 96, 128, 160, 192, 224, 256]
|
||||
@@ -346,15 +350,10 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
# TODO we should support this :think
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@@ -410,6 +409,7 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlexAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FLexAttention.
|
||||
|
||||
@@ -423,6 +423,11 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for FlexAttentionImpl")
|
||||
|
||||
enable_gqa = self.num_kv_heads != self.num_heads
|
||||
|
||||
if attn_metadata is None:
|
||||
|
||||
@@ -110,7 +110,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
@@ -120,9 +119,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
logger.warning_once(
|
||||
"Using irope in Pallas is not supported yet, it will fall back "
|
||||
"to global attention for long context.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError("Paged attention Pallas kernel does "
|
||||
"not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@@ -139,8 +135,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError("FP8 KV cache dtype is not supported.")
|
||||
if blocksparse_params is not None:
|
||||
raise NotImplementedError("Blocksparse is not supported.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
@@ -161,6 +155,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: PallasMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
|
||||
@@ -173,6 +168,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for PallasAttentionBackendImpl")
|
||||
|
||||
# For determine_available_memory case.
|
||||
if kv_cache.numel() == 0:
|
||||
if output is None:
|
||||
|
||||
@@ -3,6 +3,10 @@
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
@@ -12,7 +16,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import (
|
||||
chunked_prefill_paged_decode)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import (
|
||||
@@ -38,10 +42,25 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> list[int]:
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes.")
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_ATTN_VLLM_V1"
|
||||
@@ -74,6 +93,15 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
return TritonAttentionMetadataBuilder
|
||||
|
||||
|
||||
@cache
|
||||
def use_aiter_unified_attention() -> bool:
|
||||
"""Check if aiter unified attention should be used."""
|
||||
# VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set
|
||||
# to 1 as default
|
||||
return envs.VLLM_ROCM_USE_AITER \
|
||||
and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
|
||||
|
||||
|
||||
class TritonAttentionImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
@@ -85,16 +113,11 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"TritonAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@@ -113,16 +136,9 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.use_irope = use_irope
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
|
||||
if head_size not in support_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by TritonAttention. "
|
||||
f"Supported head sizes are: {support_head_sizes}.")
|
||||
TritonAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
@@ -133,7 +149,23 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
self.force_prefill_decode_attn = \
|
||||
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
|
||||
|
||||
|
||||
if not self.force_prefill_decode_attn:
|
||||
# If not using prefill decode attention, we use the Triton
|
||||
# unified attention implementation.
|
||||
if use_aiter_unified_attention():
|
||||
logger.info_once(
|
||||
"Using aiter unified attention for TritonAttentionImpl")
|
||||
from aiter.ops.triton.unified_attention import (
|
||||
unified_attention)
|
||||
self.unified_attention = unified_attention
|
||||
else:
|
||||
logger.info_once(
|
||||
"Using vllm unified attention for TritonAttentionImpl")
|
||||
from vllm.attention.ops.triton_unified_attention import (
|
||||
unified_attention)
|
||||
self.unified_attention = unified_attention
|
||||
|
||||
self.sinks = sinks
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == num_heads, (
|
||||
@@ -150,6 +182,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
@@ -164,6 +197,11 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for TritonAttentionImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
@@ -227,51 +265,41 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
query.reshape(
|
||||
(num_tokens, num_heads * head_size)).contiguous(),
|
||||
layer._q_scale)
|
||||
query = query.reshape((num_tokens, num_heads, head_size))
|
||||
query = query.reshape((num_tokens, num_heads, head_size))
|
||||
|
||||
use_local_attn = \
|
||||
(self.use_irope and attn_metadata.local_attn_metadata is not None)
|
||||
|
||||
if use_local_attn:
|
||||
assert attn_metadata.local_attn_metadata is not None
|
||||
local_metadata = attn_metadata.local_attn_metadata
|
||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
||||
seqused_k = local_metadata.local_seqused_k
|
||||
max_seqlen_q = local_metadata.local_max_query_len
|
||||
max_seqlen_k = local_metadata.local_max_seq_len
|
||||
block_table = local_metadata.local_block_table
|
||||
else:
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
if use_prefill_decode_attn:
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=block_table,
|
||||
query_start_loc=cu_seqlens_q,
|
||||
seq_lens=seqused_k,
|
||||
max_seq_len=max_seqlen_k,
|
||||
max_query_len=max_seqlen_q,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale,
|
||||
sinks=self.sinks)
|
||||
chunked_prefill_paged_decode(
|
||||
query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=block_table,
|
||||
query_start_loc=cu_seqlens_q,
|
||||
seq_lens=seqused_k,
|
||||
max_seq_len=max_seqlen_k,
|
||||
max_query_len=max_seqlen_q,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale,
|
||||
sinks=self.sinks,
|
||||
)
|
||||
|
||||
else:
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
|
||||
unified_attention(
|
||||
self.unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
|
||||
Reference in New Issue
Block a user