Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -8,6 +8,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm import envs
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
@@ -130,13 +131,12 @@ class SiluAndMul(CustomOp):
def __init__(self, *, compile_native: bool = True):
super().__init__(compile_native=compile_native)
if current_platform.is_cuda_alike():
if current_platform.is_cuda_alike() or current_platform.is_xpu():
from vllm import _custom_ops as ops
self.op = ops.silu_and_mul
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
self.op = ipex_ops.silu_and_mul
if envs.VLLM_USE_SILU_QUANT_FUSION:
self.op = ops.silu_and_mul_quant
else:
self.op = ops.silu_and_mul
elif current_platform.is_cpu():
self._forward_method = self.forward_native
@@ -146,11 +146,15 @@ class SiluAndMul(CustomOp):
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self, x: torch.Tensor, out_dim: int = 0) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
if envs.VLLM_USE_SILU_QUANT_FUSION:
quant_out, out_scales = self.op(x, out_dim)
out = (quant_out, out_scales, x.dtype)
else:
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@@ -174,7 +178,6 @@ class MulAndSilu(CustomOp):
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_xpu():
# self.op = torch.ops._C.mul_and_silu
from vllm import _custom_ops as ops
self.op = ops.mul_and_silu
elif current_platform.is_cpu():
@@ -397,7 +400,6 @@ class NewGELU(CustomOp):
or current_platform.is_cpu()
or current_platform.is_xpu()
):
# self.op = torch.ops._C.gelu_new
from vllm import _custom_ops as ops
self.op = ops.gelu_new
@@ -427,7 +429,8 @@ class FastGELU(CustomOp):
or current_platform.is_cpu()
or current_platform.is_xpu()
):
self.op = torch.ops._C.gelu_fast
from vllm import _custom_ops as ops
self.op = ops.gelu_fast
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
@@ -455,7 +458,6 @@ class QuickGELU(CustomOp):
or current_platform.is_cpu()
or current_platform.is_xpu()
):
# self.op = torch.ops._C.gelu_quick
from vllm import _custom_ops as ops
self.op = ops.gelu_quick
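The new VLLM_USE_SILU_QUANT_FUSION path changes the return contract of SiluAndMul.forward_cuda: instead of a single tensor it returns the quantized activation, its scales, and the original dtype. A minimal caller-side sketch (not part of this commit; it assumes a CUDA build of this fork and uses only names that appear in the hunk above):

import torch
from vllm import envs
from vllm.model_executor.layers.activation import SiluAndMul

act = SiluAndMul()
x = torch.randn(4, 2048, device="cuda", dtype=torch.float16)
out = act.forward_cuda(x)
if envs.VLLM_USE_SILU_QUANT_FUSION:
    # Fused path: (quantized activation, scales, original dtype).
    quant_out, out_scales, orig_dtype = out
else:
    # Unfused path: a plain tensor of shape [4, 1024].
    y = out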

View File

@@ -12,7 +12,7 @@ from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer,
maybe_transfer_kv_layer
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
@@ -40,6 +40,9 @@ from vllm.v1.kv_cache_interface import (
KVCacheSpec,
SlidingWindowSpec,
)
from .extra_cache import StaticQuantManager
from ixformer.core import config
_USE_TORCH_OPS = config.IXFORMER_USE_TORCH_OPS
if TYPE_CHECKING:
from vllm.model_executor.layers.attention import MLAAttention
@@ -202,6 +205,7 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None,
head_size_v: int | None = None,
extra_cache_para: dict | None = None,
**extra_impl_args,
) -> None:
"""
@@ -258,6 +262,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.num_heads = num_heads
self.head_size = head_size
self.hidden_size = head_size * num_heads
self.head_size_v = self.head_size if head_size_v is None else head_size_v
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
@@ -326,6 +331,15 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name,
**extra_impl_args,
)
if extra_cache_para is not None:
self.quant_manager = StaticQuantManager(
layer_id=extra_cache_para.get("layer_id", None),
shape=(self.num_kv_heads, self.head_size_v),
dtype=torch.float32,
total_layer_num=extra_cache_para.get("total_layer_num", None)
)
else:
self.quant_manager = None
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
self.dtype = dtype
@@ -333,7 +347,10 @@ class Attention(nn.Module, AttentionLayerBase):
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.opaque_attention_op()
if _USE_TORCH_OPS:
self.use_direct_call = False
else:
self.use_direct_call = True
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = vllm_config.compilation_config
@@ -349,14 +366,26 @@ class Attention(nn.Module, AttentionLayerBase):
compilation_config.static_forward_context,
)
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([])
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
self.is_i8qi8ki8v = envs.VLLM_ATTN_OPT_LEVEL == 1
self.is_i8qi8kf16v = envs.VLLM_ATTN_OPT_LEVEL == 2
if self.is_i8qi8kf16v:
self.kv_cache_scale = [
torch.tensor([])
for _ in range(get_current_vllm_config().parallel_config.pipeline_parallel_size)
]
elif self.is_i8qi8ki8v:
self.kv_cache_scale = [
[torch.tensor([]), torch.tensor([])]
for _ in range(get_current_vllm_config().parallel_config.pipeline_parallel_size)
]
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
# Initialize KV cache quantization attributes
_init_kv_cache_quant(self, quant_config, prefix)
@@ -396,6 +425,7 @@ class Attention(nn.Module, AttentionLayerBase):
context using
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
optional_args = {}
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
output_dtype = query.dtype
@@ -412,15 +442,8 @@ class Attention(nn.Module, AttentionLayerBase):
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
if output_shape is None:
# Handle both 2D [num_tokens, hidden] and
# 3D [num_tokens, heads, head_dim] query
num_tokens = query.shape[0]
output_shape = torch.Size(
(num_tokens, self.num_heads * self.head_size_v)
)
output_shape = output_shape if output_shape is not None else query.shape
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
@@ -430,46 +453,50 @@ class Attention(nn.Module, AttentionLayerBase):
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size_v)
kv_cache_dummy_dep = None
if self.use_direct_call:
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = unified_kv_cache_update(
key, value, self.layer_name
def direct_forward(layer_name: str, output: torch.Tensor):
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# Skip this if sharing KV cache with an earlier attention layer.
if self.is_i8qi8ki8v or self.is_i8qi8kf16v:
optional_args["kv_cache_scale"] = self.kv_cache_scale[forward_context.virtual_engine]
output = self.impl.forward(
self,
query,
key,
value,
self_kv_cache,
attn_metadata,
output=output,
**optional_args
)
unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output
return maybe_transfer_kv_layer(direct_forward)(self.layer_name, output)
else:
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name
)
if self.is_i8qi8ki8v:
forward_context: ForwardContext = get_forward_context()
kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine][0]
v_cache_scale = self.kv_cache_scale[forward_context.virtual_engine][1]
elif self.is_i8qi8kf16v:
forward_context: ForwardContext = get_forward_context()
kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine]
v_cache_scale = None
else:
kv_cache_scale = None
v_cache_scale = None
torch.ops.vllm.unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
kv_cache_scale=kv_cache_scale,
v_cache_scale=v_cache_scale,
)
return output.view(-1, hidden_size)
return output.view(-1, self.hidden_size)
else:
assert self.attn_backend.forward_includes_kv_cache_update, (
"Split KV cache update not supported when output tensor not provided."
@@ -521,6 +548,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
# TODO: the kernel does not support a sliding-window KV cache spec; use FullAttentionSpec instead
if self.sliding_window is not None:
assert not vllm_config.model_config.use_mla, (
"MLA is not supported for slidingwindow"
@@ -689,6 +717,8 @@ def unified_attention_with_output(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
kv_cache_scale: torch.Tensor | None = None,
v_cache_scale: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
@@ -696,9 +726,7 @@ def unified_attention_with_output(
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del kv_cache_dummy_dep
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
self.impl.forward(
self,
query,
@@ -707,6 +735,7 @@ def unified_attention_with_output(
kv_cache,
attn_metadata,
output=output,
kv_cache_scale=[kv_cache_scale, v_cache_scale] if envs.VLLM_ATTN_OPT_LEVEL == 1 else kv_cache_scale,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
@@ -718,6 +747,8 @@ def unified_attention_with_output_fake(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
kv_cache_scale: torch.Tensor | None = None,
v_cache_scale: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
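To summarize the new KV-cache quantization plumbing above (a hedged reading of the hunk, not text from the commit): VLLM_ATTN_OPT_LEVEL=1 (int8 Q/K/V) carries a [k_scale, v_scale] pair per virtual engine, while level 2 (int8 Q/K, fp16 V) carries a single scale tensor; both are threaded through unified_attention_with_output as kv_cache_scale / v_cache_scale. A small sketch of the placeholder layout set up in Attention.__init__:

import torch

def make_kv_cache_scale_placeholders(opt_level: int, pp_size: int):
    # opt_level 2 (i8q/i8k/f16v): one scale tensor per virtual engine.
    if opt_level == 2:
        return [torch.tensor([]) for _ in range(pp_size)]
    # opt_level 1 (i8q/i8k/i8v): a [k_scale, v_scale] pair per virtual engine.
    if opt_level == 1:
        return [[torch.tensor([]), torch.tensor([])] for _ in range(pp_size)]
    return None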

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import torch
from filelock import FileLock
import vllm.envs as envs
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
class StaticQuantManager:
def __init__(
self,
layer_id: int,
shape: tuple,
dtype: torch.dtype,
total_layer_num: int,
device: str | None = None,
tp_size: int | None = None,
tp_rank: int | None = None,
file_save_path: str | None = None,
save_step: int = 100,
info_step: int = 100,
):
# fill in defaults for unspecified parameters
if tp_size is None:
tp_size = get_tensor_model_parallel_world_size()
if tp_rank is None:
tp_rank = get_tensor_model_parallel_rank()
if file_save_path is None:
file_save_path = envs.VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH
if device is None:
device = "cuda"
# validate parameters
if file_save_path in [None, ""]:
self.disable = True
return
para_dir = os.path.dirname(file_save_path)
assert os.path.exists(para_dir), (
f"StaticQuantManager workdir {para_dir} not exist!"
)
self.disable = os.path.exists(file_save_path)
if self.disable:
return
assert layer_id is not None
assert total_layer_num is not None
world_rank = torch.distributed.get_rank()
work_dir = os.path.join(para_dir, "StaticQuantManagerWorkdir")
self.operator = world_rank == 0 and layer_id == 0
if not os.path.exists(work_dir):
if self.operator:
logger.debug(f"StaticQuantManager Creat {work_dir}!")
os.mkdir(work_dir)
self.file_save_path = file_save_path
self.work_dir = work_dir
self.tp_size = tp_size
self.tp_rank = tp_rank
self.world_rank = world_rank
self.layer_id = layer_id
self.total_layer_num = total_layer_num
self.save_step = save_step
self.info_step = info_step
self.update_count = 0
self.save_flag = False
self.scales = torch.zeros(shape, dtype=dtype, device=device)
logger.debug(
f"StaticQuantManager info: world_rank:{self.world_rank} tp_rank:{self.tp_rank} "
f"layer_id:{self.layer_id} scale shape:{shape} device:{self.scales.device}"
)
def check_enable(self):
return not self.disable
def update_data(self, data):
if self.disable:
return
self.scales = torch.max(data, self.scales)
# save file
self.update_count += 1
if self.update_count % self.info_step == 0 and self.operator:
logger.info(f"StaticQuantManager run update_data {self.update_count} step")
if self.update_count % self.save_step == 0:
# step1: save to disk
save_file_path = os.path.join(
self.work_dir, f"{self.layer_id}_{self.tp_rank}.pt"
)
lock_file_path = os.path.join(
self.work_dir, f"{self.layer_id}_{self.tp_rank}.lock"
)
lock = FileLock(lock_file_path)
cpu_data = self.scales.cpu()
with lock:
torch.save(cpu_data, save_file_path)
# step2: merge and save
if self.save_flag and self.operator:
save_dict = {}
for idx in range(self.total_layer_num):
tp_datas = []
for tp_rank in range(self.tp_size):
load_file = os.path.join(self.work_dir, f"{idx}_{tp_rank}.pt")
lock_file_path = os.path.join(
self.work_dir, f"{idx}_{tp_rank}.lock"
)
lock = FileLock(lock_file_path)
with lock:
cur_data = torch.load(load_file)
tp_datas.append(cur_data)
layer_data = torch.concat(tp_datas)
save_dict[f"layer_{idx}"] = layer_data
torch.save(save_dict, self.file_save_path)
logger.info(
f"StaticQuantManager save to {self.file_save_path} with {self.update_count} step"
)
self.save_flag = True
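A minimal usage sketch for the new StaticQuantManager (illustrative values only; it assumes torch.distributed is already initialized and that file_save_path, or VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH, points at a writable location):

import torch

# Hypothetical model: 32 layers, 8 KV heads, head_size_v=128.
mgr = StaticQuantManager(
    layer_id=0,
    shape=(8, 128),
    dtype=torch.float32,
    total_layer_num=32,
    tp_size=1,
    tp_rank=0,
    file_save_path="/tmp/kv_static_scales.pt",  # assumed path
)
if mgr.check_enable():
    # Feed the per-head absolute maxima observed for the current tokens;
    # every save_step a per-layer/per-rank shard is written, and rank 0 of
    # layer 0 later merges all shards into file_save_path.
    observed = torch.randn(8, 128, device="cuda").abs()
    mgr.update_data(observed)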

File diff suppressed because it is too large.

View File

@@ -2,21 +2,94 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_flashinfer_wrapper,
vit_torch_sdpa_wrapper,
vit_triton_attn_wrapper,
)
import ixformer.contrib.vllm_flash_attn as ops
logger = init_logger(__name__)
# Batch buckets for cuDNN graph caching.
# Graphs use batch size and max sequence length as cache key.
# This avoids creating a new graph for each unique set of
# batch size and max sequence length at runtime.
# From the cuDNN team's performance measurements, there
# is no significant kernel performance difference between padding
# to a smaller batch size/seq length and padding to larger
# ones. The bucketing here is solely used to avoid memory
# operation overhead, which won't be needed if we have CUDA
# graph support in the future.
# TODO: Remove buckets after issue #34763
# (cuda graph support) is addressed.
FLASHINFER_BATCH_BUCKETS = [8, 16, 32, 64]
FLASHINFER_MAX_SEQLEN_BUCKETS = [
1 * 1024,
2 * 1024,
4 * 1024,
8 * 1024,
16 * 1024,
32 * 1024,
64 * 1024,
128 * 1024,
]
# Workspace buffer for FlashInfer CuDNN backend
FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024
_flashinfer_workspace_buffer: torch.Tensor | None = None
def _get_flashinfer_workspace_buffer() -> torch.Tensor:
global _flashinfer_workspace_buffer
if _flashinfer_workspace_buffer is None:
_flashinfer_workspace_buffer = torch.zeros(
FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES,
dtype=torch.uint8,
device="cuda",
)
return _flashinfer_workspace_buffer
def add_padding_to_seqlens(
seq: np.ndarray,
batch_size: int,
padding_value: int,
) -> np.ndarray:
batch_size_padded = next(
(b for b in FLASHINFER_BATCH_BUCKETS if b >= batch_size),
round_up(batch_size, FLASHINFER_BATCH_BUCKETS[0]),
)
if batch_size_padded == batch_size:
return seq
return np.concatenate(
[
seq,
np.full((batch_size_padded - batch_size,), padding_value, dtype=seq.dtype),
]
)
def bucket_flashinfer_max_seqlen(
real_max_seqlen: int,
) -> int:
if real_max_seqlen <= 0:
return FLASHINFER_MAX_SEQLEN_BUCKETS[0]
return next(
(s for s in FLASHINFER_MAX_SEQLEN_BUCKETS if s >= real_max_seqlen),
round_up(real_max_seqlen, FLASHINFER_MAX_SEQLEN_BUCKETS[-1]),
)
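For intuition about the two helpers above, a hand-checked example (values chosen arbitrarily): a real max sequence length of 3000 rounds up to the 4096-token bucket, and a batch of 20 sequences pads up to the 32-entry batch bucket.

import numpy as np

assert bucket_flashinfer_max_seqlen(3000) == 4 * 1024
seq = np.arange(20, dtype=np.int32)
padded = add_padding_to_seqlens(seq, batch_size=20, padding_value=0)
assert padded.shape == (32,)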
# --8<-- [start:mm_encoder_attn]
@CustomOp.register("mm_encoder_attn")
@@ -24,6 +97,67 @@ class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""
# --8<-- [end:mm_encoder_attn]
@classmethod
def compute_max_seqlen(
cls,
attn_backend: AttentionBackendEnum,
cu_seqlens: np.ndarray,
) -> int:
max_seqlen = 0
if (
attn_backend
in (
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLASHINFER,
)
and len(cu_seqlens) >= 2
):
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
if attn_backend == AttentionBackendEnum.FLASHINFER:
max_seqlen = bucket_flashinfer_max_seqlen(max_seqlen)
return max_seqlen
@classmethod
def maybe_compute_sequence_lengths(
cls,
attn_backend: AttentionBackendEnum,
cu_seqlens: np.ndarray,
) -> np.ndarray | None:
if attn_backend != AttentionBackendEnum.FLASHINFER:
return None
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
sequence_lengths = add_padding_to_seqlens(
sequence_lengths, len(sequence_lengths), 0
)
return sequence_lengths
@classmethod
def maybe_recompute_cu_seqlens(
cls,
attn_backend: AttentionBackendEnum,
cu_seqlens: np.ndarray,
hidden_size: int,
tp_size: int,
) -> np.ndarray:
if attn_backend != AttentionBackendEnum.FLASHINFER:
return cu_seqlens
batch_size = len(cu_seqlens) - 1
scale = hidden_size // tp_size
cu_seqlens = cu_seqlens * scale
cu_seqlens_qko = cu_seqlens
cu_seqlens_v = cu_seqlens * 3
cu_seqlens_qko = add_padding_to_seqlens(
cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
)
cu_seqlens_v = add_padding_to_seqlens(
cu_seqlens_v, batch_size, cu_seqlens_v[-1]
)
return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
def __init__(
self,
@@ -46,10 +180,9 @@ class MMEncoderAttention(CustomOp):
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.scale = 1.0 / (head_size**0.5) if scale is None else scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
@@ -72,9 +205,14 @@ class MMEncoderAttention(CustomOp):
}
self._fa_version = (
get_flash_attn_version() if self.is_flash_attn_backend else None
get_flash_attn_version(head_size=head_size)
if self.is_flash_attn_backend
else None
)
if self.attn_backend == AttentionBackendEnum.FLASHINFER:
_get_flashinfer_workspace_buffer()
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod
@@ -148,23 +286,27 @@ class MMEncoderAttention(CustomOp):
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
query = query.view(bsz * q_len, self.num_heads, self.head_size)
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
output = vit_flash_attn_wrapper(
q=query,
k=key,
v=value,
batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
fa_version=self._fa_version,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
cu_q = torch.tensor(
[0] + [q_len] * bsz, device=query.device, dtype=torch.int32
).cumsum(dim=0, dtype=torch.int32)
cu_kv = torch.tensor(
[0] + [kv_len] * bsz, device=query.device, dtype=torch.int32
).cumsum(dim=0, dtype=torch.int32)
out = ops.flash_attn_varlen_func(
query,
key,
value,
cu_q,
cu_kv,
q_len,
kv_len,
softmax_scale=self.scale,
causal=False,
)
out = out.view(bsz, q_len, self.num_heads, self.head_size)
if is_reshaped:
output = output.reshape(bsz, q_len, -1)
return output
out = out.reshape(bsz, q_len, -1)
return out
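To make the cumulative-length construction above concrete (hand-worked, not from the commit): for bsz=2 and q_len=3, the tensor [0, 3, 3] cumsums to [0, 3, 6], i.e. the standard varlen offsets for two equal-length sequences.

import torch

bsz, q_len = 2, 3
cu_q = torch.tensor([0] + [q_len] * bsz, dtype=torch.int32).cumsum(0, dtype=torch.int32)
# cu_q == tensor([0, 3, 6], dtype=torch.int32)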
def _forward_triton(
self,
@@ -201,6 +343,27 @@ class MMEncoderAttention(CustomOp):
output = output.reshape(bsz, q_len, -1)
return output
def _forward_flashinfer(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return vit_flashinfer_wrapper(
q=query,
k=key,
v=value,
scale=self.scale,
workspace_buffer=_get_flashinfer_workspace_buffer(),
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
def forward_native(
self,
query: torch.Tensor,
@@ -208,6 +371,8 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
@@ -218,11 +383,17 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.FLASHINFER:
return self._forward_flashinfer(
query, key, value, cu_seqlens, max_seqlen, sequence_lengths
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
@@ -238,6 +409,8 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
@@ -248,6 +421,8 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
sequence_lengths: torch.Tensor
| None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

View File

@@ -7,11 +7,17 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from .chunk import chunk_gated_delta_rule
from .fused_recurrent import fused_recurrent_gated_delta_rule
from .fused_recurrent import (
fused_recurrent_gated_delta_rule,
fused_recurrent_gated_delta_rule_packed_decode,
)
from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update
from .layernorm_guard import RMSNormGated
__all__ = [
"RMSNormGated",
"chunk_gated_delta_rule",
"fused_recurrent_gated_delta_rule",
"fused_recurrent_gated_delta_rule_packed_decode",
"fused_sigmoid_gating_delta_rule_update",
]

View File

@@ -30,7 +30,7 @@ def chunk_gated_delta_rule_fwd(
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
@@ -84,7 +84,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
@@ -117,7 +117,7 @@ def chunk_gated_delta_rule(
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
@@ -141,7 +141,7 @@ def chunk_gated_delta_rule(
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
cu_seqlens (torch.LongTensor):
cu_seqlens (torch.Tensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
@@ -171,7 +171,7 @@ def chunk_gated_delta_rule(
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
>>> o_var, ht_var = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,

View File

@@ -288,7 +288,7 @@ def chunk_gated_delta_rule_fwd_h(
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.

View File

@@ -89,7 +89,7 @@ def chunk_fwd_kernel_o(
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
@@ -145,7 +145,7 @@ def chunk_fwd_o(
h: torch.Tensor,
g: torch.Tensor | None = None, # cumsum of log decay
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]

View File

@@ -102,7 +102,7 @@ def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
g: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
@@ -116,7 +116,7 @@ def chunk_scaled_dot_kkt_fwd(
The beta tensor of shape `[B, T, H]`.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
cu_seqlens (torch.LongTensor):
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):

View File

@@ -106,12 +106,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
# Load state index and check for PAD_SLOT_ID (-1)
# Load state index and check for invalid entries
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
# Skip if state index is invalid (PAD_SLOT_ID = -1)
if state_idx < 0:
# Skip if state index is invalid (NULL_BLOCK_ID=0)
if state_idx <= 0:
return
p_h0 = h0 + state_idx * stride_init_state_token
else:
@@ -150,12 +150,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
# Load state index and check for PAD_SLOT_ID (-1)
# Load state index and check for invalid entries
final_state_idx = tl.load(
ssm_state_indices + i_n * stride_indices_seq + i_t
).to(tl.int64)
# Only store if state index is valid (not PAD_SLOT_ID)
if final_state_idx >= 0:
# Only store if state index is valid (not NULL_BLOCK_ID=0)
if final_state_idx > 0:
p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
@@ -184,7 +184,7 @@ def fused_recurrent_gated_delta_rule_fwd(
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -252,6 +252,232 @@ def fused_recurrent_gated_delta_rule_fwd(
return o, final_state
@triton.jit
def fused_recurrent_gated_delta_rule_packed_decode_kernel(
mixed_qkv,
a,
b,
A_log,
dt_bias,
o,
h0,
ht,
ssm_state_indices,
scale,
stride_mixed_qkv_tok: tl.constexpr,
stride_a_tok: tl.constexpr,
stride_b_tok: tl.constexpr,
stride_init_state_token: tl.constexpr,
stride_final_state_token: tl.constexpr,
stride_indices_seq: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
SOFTPLUS_THRESHOLD: tl.constexpr,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
):
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
o_k = tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_v[:, None] & mask_k[None, :]
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
p_o = o + (i_n * HV + i_hv) * V + o_v
# Skip if state index is invalid (NULL_BLOCK_ID=0)
if state_idx <= 0:
zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
tl.store(p_o, zero, mask=mask_v)
return
p_h0 = h0 + state_idx * stride_init_state_token
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
q_off = i_h * K + o_k
k_off = (H * K) + i_h * K + o_k
v_off = (2 * H * K) + i_hv * V + o_v
b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)
A_log_val = tl.load(A_log + i_hv).to(tl.float32)
dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)
x = a_val + dt_bias_val
softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
g_val = -tl.exp(A_log_val) * softplus_x
beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32)
b_h *= exp(g_val)
b_v -= tl.sum(b_h * b_k[None, :], 1)
b_v *= beta_val
b_h += b_v[:, None] * b_k[None, :]
b_o = tl.sum(b_h * b_q[None, :], 1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
p_ht = ht + state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
def fused_recurrent_gated_delta_rule_packed_decode(
mixed_qkv: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
A_log: torch.Tensor,
dt_bias: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
out: torch.Tensor,
ssm_state_indices: torch.Tensor,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if mixed_qkv.ndim != 2:
raise ValueError(
f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})."
)
if mixed_qkv.stride(-1) != 1:
raise ValueError("`mixed_qkv` must be contiguous in the last dim.")
if a.ndim != 2 or b.ndim != 2:
raise ValueError(
f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})."
)
if a.stride(-1) != 1 or b.stride(-1) != 1:
raise ValueError("`a`/`b` must be contiguous in the last dim.")
if A_log.ndim != 1 or dt_bias.ndim != 1:
raise ValueError("`A_log`/`dt_bias` must be 1D tensors.")
if A_log.stride(0) != 1 or dt_bias.stride(0) != 1:
raise ValueError("`A_log`/`dt_bias` must be contiguous.")
if ssm_state_indices.ndim != 1:
raise ValueError(
f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})."
)
if not out.is_contiguous():
raise ValueError("`out` must be contiguous.")
dev = mixed_qkv.device
if (
a.device != dev
or b.device != dev
or A_log.device != dev
or dt_bias.device != dev
or initial_state.device != dev
or out.device != dev
or ssm_state_indices.device != dev
):
raise ValueError("All inputs must be on the same device.")
B = mixed_qkv.shape[0]
if a.shape[0] != B or b.shape[0] != B:
raise ValueError(
"Mismatched batch sizes: "
f"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}."
)
if ssm_state_indices.shape[0] != B:
raise ValueError(
f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))."
)
if initial_state.ndim != 4:
raise ValueError(
f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})."
)
if initial_state.stride(-1) != 1:
raise ValueError("`initial_state` must be contiguous in the last dim.")
HV, V, K = initial_state.shape[-3:]
if a.shape[1] != HV or b.shape[1] != HV:
raise ValueError(
f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})."
)
if A_log.numel() != HV or dt_bias.numel() != HV:
raise ValueError(
f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})."
)
if out.shape != (B, 1, HV, V):
raise ValueError(
f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})."
)
qkv_dim = mixed_qkv.shape[1]
qk_dim = qkv_dim - HV * V
if qk_dim <= 0 or qk_dim % 2 != 0:
raise ValueError(
f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}."
)
q_dim = qk_dim // 2
if q_dim % K != 0:
raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.")
H = q_dim // K
if H <= 0 or HV % H != 0:
raise ValueError(
f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}."
)
BK = triton.next_power_of_2(K)
if triton.cdiv(K, BK) != 1:
raise ValueError(
f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})."
)
BV = min(triton.next_power_of_2(V), 32)
num_stages = 3
num_warps = 1
stride_mixed_qkv_tok = mixed_qkv.stride(0)
stride_a_tok = a.stride(0)
stride_b_tok = b.stride(0)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = initial_state.stride(0)
stride_indices_seq = ssm_state_indices.stride(0)
NV = triton.cdiv(V, BV)
grid = (NV, B * HV)
fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](
mixed_qkv=mixed_qkv,
a=a,
b=b,
A_log=A_log,
dt_bias=dt_bias,
o=out,
h0=initial_state,
ht=initial_state,
ssm_state_indices=ssm_state_indices,
scale=scale,
stride_mixed_qkv_tok=stride_mixed_qkv_tok,
stride_a_tok=stride_a_tok,
stride_b_tok=stride_b_tok,
stride_init_state_token=stride_init_state_token,
stride_final_state_token=stride_final_state_token,
stride_indices_seq=stride_indices_seq,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
SOFTPLUS_THRESHOLD=20.0,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
num_warps=num_warps,
num_stages=num_stages,
)
return out, initial_state
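A sketch of the packed layout the decode kernel expects per token (illustrative sizes; the slices mirror the q_off/k_off/v_off offsets computed in the kernel): queries occupy the first H*K elements, keys the next H*K, and values the final HV*V.

import torch

H, HV, K, V = 4, 8, 128, 128   # illustrative head/feature sizes
B = 16                          # decode batch: one token per sequence
mixed_qkv = torch.randn(B, 2 * H * K + HV * V, device="cuda", dtype=torch.bfloat16)
q = mixed_qkv[:, : H * K]               # q_off = i_h * K + o_k
k = mixed_qkv[:, H * K : 2 * H * K]     # k_off = H * K + i_h * K + o_k
v = mixed_qkv[:, 2 * H * K :]           # v_off = 2 * H * K + i_hv * V + o_v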
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
def forward(
@@ -264,7 +490,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -296,7 +522,7 @@ def fused_recurrent_gated_delta_rule(
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -324,7 +550,7 @@ def fused_recurrent_gated_delta_rule(
inplace_final_state: bool:
Whether to store the final state in-place to save memory.
Default: `True`.
cu_seqlens (torch.LongTensor):
cu_seqlens (torch.Tensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
ssm_state_indices (Optional[torch.Tensor]):
@@ -358,7 +584,7 @@ def fused_recurrent_gated_delta_rule(
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
q, k, v, g, beta,
initial_state=h0,

View File

@@ -0,0 +1,279 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import torch
from vllm.triton_utils import tl, triton
@triton.heuristics(
{
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
"IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
}
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
A_log,
a,
b,
dt_bias,
beta,
threshold,
q,
k,
v,
o,
h0,
ht,
cu_seqlens,
ssm_state_indices,
num_accepted_tokens,
scale,
N: tl.int64, # num of sequences
T: tl.int64, # num of tokens
B: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
stride_init_state_token: tl.constexpr,
stride_final_state_token: tl.constexpr,
stride_indices_seq: tl.constexpr,
stride_indices_tok: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
IS_KDA: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
if IS_VARLEN:
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int64),
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
)
all = T
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T
if T == 0:
# no tokens to process for this sequence
return
o_k = i_k * BK + tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)
p_q = q + (bos * H + i_h) * K + o_k
p_k = k + (bos * H + i_h) * K + o_k
p_v = v + (bos * HV + i_hv) * V + o_v
p_A_log = A_log + i_hv
if not IS_KDA:
p_a = a + bos * HV + i_hv
p_dt_bias = dt_bias + i_hv
else:
p_a = a + (bos * HV + i_hv) * K + o_k
p_dt_bias = dt_bias + i_hv * K + o_k
p_b = b + bos * HV + i_hv
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_v[:, None] & mask_k[None, :]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
if IS_CONTINUOUS_BATCHING:
if IS_SPEC_DECODING:
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
# Load state index and check for invalid entries
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
# Skip if state index is invalid (NULL_BLOCK_ID=0)
if state_idx <= 0:
return
p_h0 = h0 + state_idx * stride_init_state_token
else:
p_h0 = h0 + bos * HV * V * K
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
for i_t in range(0, T):
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_b = tl.load(p_b).to(tl.float32)
# If the model is loaded in fp16, without the .float() here, A might be -inf
x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32)
softplus_x = tl.where(
beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
)
b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x
# compute beta_output = sigmoid(b)
b_beta = tl.sigmoid(b_b.to(tl.float32))
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q * (tl.rsqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k * (tl.rsqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale
# [BV, BK]
if not IS_KDA:
b_h *= tl.exp(b_g)
else:
b_h *= tl.exp(b_g[None, :])
# [BV]
b_v -= tl.sum(b_h * b_k[None, :], 1)
b_v *= b_beta
# [BV, BK]
b_h += b_v[:, None] * b_k[None, :]
# [BV]
b_o = tl.sum(b_h * b_q[None, :], 1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
# Load state index and check for invalid entries
final_state_idx = tl.load(
ssm_state_indices + i_n * stride_indices_seq + i_t
).to(tl.int64)
# Only store if state index is valid (not NULL_BLOCK_ID=0)
if final_state_idx > 0:
p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
else:
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
# Update pointers for next timestep
p_q += H * K
p_k += H * K
p_o += HV * V
p_v += HV * V
p_b += HV
p_a += HV
def fused_sigmoid_gating_delta_rule_update(
A_log: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
dt_bias: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: float = 1.0,
threshold: float = 20.0,
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
is_kda: bool = False,
):
"""
Fused triton implementation of sigmoid gating delta rule update.
This function uses a single fused kernel that combines both sigmoid gating
computation and the recurrent delta rule update for better performance.
"""
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 3
num_warps = 4
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]}"
f" when using `cu_seqlens`. Please flatten variable-length"
f" inputs before processing."
)
if scale is None:
scale = k.shape[-1] ** -0.5
else:
assert scale > 0, "scale must be positive"
o = q.new_empty(NK, *v.shape)
if inplace_final_state:
final_state = initial_state
else:
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0)
if ssm_state_indices is None:
stride_indices_seq, stride_indices_tok = 1, 1
elif ssm_state_indices.ndim == 1:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
else:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
grid = (NK, NV, N * HV)
fused_sigmoid_gating_delta_rule_update_kernel[grid](
A_log=A_log,
a=a.contiguous(),
b=b.contiguous(),
dt_bias=dt_bias,
beta=beta,
threshold=threshold,
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
o=o,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
ssm_state_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
scale=scale,
N=N,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
stride_init_state_token=stride_init_state_token,
stride_final_state_token=stride_final_state_token,
stride_indices_seq=stride_indices_seq,
stride_indices_tok=stride_indices_tok,
INPLACE_FINAL_STATE=inplace_final_state,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
IS_KDA=is_kda,
num_warps=num_warps,
num_stages=num_stages,
)
o = o.squeeze(0)
return o, final_state
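For reference, a naive single-head, single-timestep restatement of what the fused kernel computes (a sketch with made-up names, assuming the default softplus beta=1.0 and threshold=20.0 and no in-kernel QK L2 norm):

import torch
import torch.nn.functional as F

def gated_delta_step(h, q, k, v, a, b, A_log, dt_bias, scale):
    # h: [V, K] recurrent state; q, k: [K]; v: [V]; a, b, A_log, dt_bias: scalars.
    g = -torch.exp(A_log) * F.softplus(a + dt_bias)  # gating decay exponent
    beta = torch.sigmoid(b)                          # write strength
    h = h * torch.exp(g)                             # decay the state
    v = (v - h @ k) * beta                           # delta-rule residual
    h = h + torch.outer(v, k)                        # rank-1 state update
    o = h @ (q * scale)                              # read out with the query
    return o, h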

View File

@@ -15,14 +15,12 @@ from .utils import tensor_cache
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
return cu_seqlens[1:] - cu_seqlens[:-1]
@tensor_cache
def prepare_chunk_indices(
cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
indices = torch.cat(
[
torch.arange(n)
@@ -33,9 +31,7 @@ def prepare_chunk_indices(
@tensor_cache
def prepare_chunk_offsets(
cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
def prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
return torch.cat(
[cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
).cumsum(-1)
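A quick worked example for the two cached helpers above: with cu_seqlens = [0, 5, 10] and chunk_size = 4, prepare_lens yields [5, 5], each sequence needs ceil(5/4) = 2 chunks, and prepare_chunk_offsets returns [0, 2, 4].

import torch

cu_seqlens = torch.tensor([0, 5, 10], dtype=torch.int32)
prepare_lens(cu_seqlens)              # tensor([5, 5])
prepare_chunk_offsets(cu_seqlens, 4)  # tensor([0, 2, 4])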

View File

@@ -37,7 +37,7 @@ def fused_recurrent_kda_fwd(
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -115,7 +115,7 @@ def fused_recurrent_kda(
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
use_qk_l2norm_in_kernel: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.LongTensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -692,7 +692,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
gk: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -706,7 +706,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
The beta tensor of shape `[B, T, H]`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
@@ -936,7 +936,7 @@ def recompute_w_u_fwd(
A: torch.Tensor,
q: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
@@ -1104,7 +1104,7 @@ def chunk_gla_fwd_o_gk(
h: torch.Tensor,
o: torch.Tensor,
scale: float,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
):
B, T, H, K, V = *q.shape, v.shape[-1]
@@ -1148,7 +1148,7 @@ def chunk_kda_fwd(
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
):
chunk_size = 64
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
@@ -1208,7 +1208,7 @@ def chunk_kda(
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
**kwargs,
):
if scale is None:

View File

@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
HAS_Z: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
ACTIVATION: tl.constexpr,
):
# Map the program id to the starting row of X and Y it should compute.
row_start = tl.program_id(0) * ROWS_PER_BLOCK
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
if HAS_Z and not NORM_BEFORE_GATE:
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
x *= z * tl.sigmoid(z)
if ACTIVATION == "swish" or ACTIVATION == "silu":
x *= z * tl.sigmoid(z)
elif ACTIVATION == "sigmoid":
x *= tl.sigmoid(z)
# Compute mean and variance per row (reduce along axis 1)
if not IS_RMS_NORM:
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
if HAS_Z and NORM_BEFORE_GATE:
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
y *= z * tl.sigmoid(z)
if ACTIVATION == "swish" or ACTIVATION == "silu":
y *= z * tl.sigmoid(z)
elif ACTIVATION == "sigmoid":
y *= tl.sigmoid(z)
# Write output
tl.store(Y_base, y, mask=mask)
@@ -178,6 +185,7 @@ def layer_norm_fwd(
group_size: int = None,
norm_before_gate: bool = True,
is_rms_norm: bool = False,
activation: str = "swish",
):
M, N = x.shape
if group_size is None:
@@ -232,9 +240,12 @@ def layer_norm_fwd(
eps,
BLOCK_N=BLOCK_N,
ROWS_PER_BLOCK=rows_per_block,
HAS_BIAS=bias is not None,
HAS_Z=z is not None,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
ACTIVATION=activation,
)
return out, mean, rstd
@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
activation: str = "swish",
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
activation=activation,
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
ctx.group_size = group_size
ctx.norm_before_gate = norm_before_gate
ctx.is_rms_norm = is_rms_norm
ctx.activation = activation
return y.reshape(x_shape_og)
@@ -296,17 +310,25 @@ def layernorm_fn(
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
activation: str = "swish",
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation
)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
activation: str = "swish",
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
x, weight, bias, z, eps, group_size, norm_before_gate, True, activation
)
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
norm_before_gate: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
activation: str = "swish",
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.activation = activation
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
activation=self.activation,
)
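A plain-PyTorch restatement of what the new activation switch changes (a sketch, not the Triton path; it assumes norm_before_gate=False, the RMSNormGated default in this file):

import torch

def rmsnorm_gated_reference(x, z, weight, eps=1e-6, activation="swish"):
    # Gate first (norm_before_gate=False), then RMS-normalize and scale.
    if activation in ("swish", "silu"):
        x = x * (z * torch.sigmoid(z))  # SiLU/swish gate
    elif activation == "sigmoid":
        x = x * torch.sigmoid(z)        # plain sigmoid gate
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rstd * weight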

View File

@@ -122,7 +122,7 @@ def recompute_w_u_fwd(
beta: torch.Tensor,
g_cumsum: torch.Tensor,
A: torch.Tensor,
cu_seqlens: torch.LongTensor | None,
cu_seqlens: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]

View File

@@ -22,12 +22,13 @@ from vllm.model_executor.layers.fused_moe.layer import (
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoEExpertsModular,
FusedMoEPrepareAndFinalizeModular,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
@@ -61,9 +62,10 @@ __all__ = [
"MoEActivation",
"UnquantizedFusedMoEMethod",
"FusedMoeWeightScaleSupported",
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEExpertsModular",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"FusedMoEPrepareAndFinalizeModular",
"GateLinear",
"RoutingMethodType",
"SharedFusedMoE",
"ZeroExpertFusedMoE",
@@ -137,4 +139,4 @@ else:
raise NotImplementedError(f"{method} is not implemented because triton is not available.")
fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk")
fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")
fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")

View File

@@ -6,8 +6,7 @@ from enum import Enum
import torch
import torch.nn.functional as F
from vllm._custom_ops import silu_and_mul, gelu_and_mul, swigluoai_and_mul
from vllm import _custom_ops as ops
class MoEActivation(Enum):
@@ -114,14 +113,11 @@ def apply_moe_activation(
# Activations with gated multiplication (gate × activation(up))
if activation == MoEActivation.SILU:
# torch.ops._C.silu_and_mul(output, input)
silu_and_mul(output, input)
ops.silu_and_mul(output, input)
elif activation == MoEActivation.GELU:
# torch.ops._C.gelu_and_mul(output, input)
gelu_and_mul(output, input)
ops.gelu_and_mul(output, input)
elif activation == MoEActivation.SWIGLUOAI:
# torch.ops._C.swigluoai_and_mul(output, input)
swigluoai_and_mul(output, input)
ops.swigluoai_and_mul(output, input)
elif activation == MoEActivation.SWIGLUSTEP:
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
@@ -20,20 +21,15 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNaiveEP,
MoEPrepareAndFinalizeNoEP,
make_moe_prepare_and_finalize_naive_dp_ep,
make_moe_prepare_and_finalize_no_dp_ep,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori
logger = init_logger(__name__)
if current_platform.is_cuda_alike():
if has_pplx():
from .pplx_prepare_finalize import (
PplxPrepareAndFinalize,
pplx_hidden_dim_scale_bytes,
)
if has_deep_ep():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (
@@ -81,6 +77,7 @@ def maybe_make_prepare_finalize(
quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
allow_new_interface: bool = False,
use_monolithic: bool = False,
) -> FusedMoEPrepareAndFinalize | None:
# NOTE(rob): we are migrating each quant_method to hold the MK
# in all cases. The allow_new_interface=False flag allow us to fall
@@ -106,65 +103,25 @@ def maybe_make_prepare_finalize(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
return MoEPrepareAndFinalizeNaiveEP(
return make_moe_prepare_and_finalize_naive_dp_ep(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size
),
use_monolithic=use_monolithic,
)
else:
return MoEPrepareAndFinalizeNoEP()
return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
if moe.use_pplx_kernels:
assert quant_config is not None
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
moe.hidden_dim,
moe.in_dtype,
quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=all2all_manager.rank,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=hidden_scale_bytes,
)
num_dispatchers = (
all2all_manager.world_size // all2all_manager.tp_group.world_size
)
# Intranode pplx a2a takes a group name while internode does not.
if not all2all_manager.internode:
all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
num_local_experts=moe.num_local_experts,
num_dispatchers=num_dispatchers,
)
elif moe.use_deepep_ht_kernels:
if moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
all_to_all_args = dict()
all_to_all_args: dict[str, Any] = dict()
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = DeepEPHTPrepareAndFinalize(
handle,
@@ -246,8 +203,9 @@ def maybe_make_prepare_finalize(
)
elif moe.use_naive_all2all_kernels and allow_new_interface:
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep(
use_monolithic=use_monolithic,
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=all2all_manager.world_size,
)

View File

@@ -261,7 +261,7 @@ def persistent_masked_m_silu_mul_quant(
return y_q, y_s
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,

View File

@@ -228,6 +228,7 @@ class FusedMoEQuantConfig:
_a2: FusedMoEQuantDesc
_w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc
is_nvfp4_scale_swizzled: bool = True
def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
@@ -475,6 +476,7 @@ class FusedMoEQuantConfig:
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
weight_dtype: torch.dtype | str | None = None,
is_nvfp4_scale_swizzled: bool = True,
) -> "FusedMoEQuantConfig":
"""
General builder function for a FusedMoEQuantConfig.
@@ -504,6 +506,7 @@ class FusedMoEQuantConfig:
- w2_bias: Optional biases for w1 (GPT OSS Triton).
- w1_zp: Optional w1 zero points for int4/int8 quantization.
- w2_zp: Optional w2 zero points for int4/int8 quantization.
- is_nvfp4_scale_swizzled: Whether the nvfp4 scales are stored in swizzled layout.
"""
assert not isinstance(quant_dtype, str) or quant_dtype in {
"nvfp4",
@@ -536,6 +539,7 @@ class FusedMoEQuantConfig:
_w2=FusedMoEQuantDesc(
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
),
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
)
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
@@ -737,6 +741,7 @@ def nvfp4_moe_quant_config(
w2_scale: torch.Tensor,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
is_nvfp4_scale_swizzled: bool = True,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for nvfp4 activations and nvfp4 weights.
@@ -754,6 +759,7 @@ def nvfp4_moe_quant_config(
per_act_token_quant=False,
per_out_ch_quant=False,
block_shape=None,
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
)
@@ -939,10 +945,6 @@ class FusedMoEParallelConfig:
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep
@property
def use_pplx_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "pplx"
@property
def use_deepep_ht_kernels(self):
return (
@@ -962,7 +964,7 @@ class FusedMoEParallelConfig:
@property
def use_batched_activation_format(self):
return self.use_deepep_ll_kernels or self.use_pplx_kernels
return self.use_deepep_ll_kernels
@property
def use_naive_all2all_kernels(self):
@@ -1221,10 +1223,6 @@ class FusedMoEConfig:
def use_ep(self):
return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels

View File

@@ -0,0 +1,147 @@
{
"triton_version": "3.6.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
}
}

View File

@@ -0,0 +1,147 @@
{
"triton_version": "3.6.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
}
}
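The two JSON tables above are tuned Triton launch configurations, keyed by token-count buckets ("1", "2", ..., "4096") together with the Triton version they were tuned against. A minimal sketch of how such a table is typically consumed (the helper name and file path are hypothetical, not part of this change): pick the bucket closest to the actual number of tokens and use its block sizes as kernel launch parameters.

import json

def pick_block_config(table_path: str, num_tokens: int) -> dict:
    # Load the tuned table produced by the benchmarking script.
    with open(table_path) as f:
        table = json.load(f)
    # Keys other than "triton_version" are token-count buckets.
    buckets = [int(k) for k in table if k.isdigit()]
    # Use the bucket closest to the actual workload size.
    best = min(buckets, key=lambda k: abs(k - num_tokens))
    return table[str(best)]

# Example: pick_block_config("E64_N1024.json", num_tokens=300) returns the
# "256" entry under this nearest-bucket rule.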

View File

@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_unpermute,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
MoEPrepareAndFinalizeNoDPEPModular,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
@@ -166,7 +166,7 @@ def run_cutlass_moe_fp8(
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
ops.get_cutlass_pplx_moe_mm_data(
ops.get_cutlass_batched_moe_mm_data(
expert_offsets,
problem_sizes1,
problem_sizes2,
@@ -262,7 +262,7 @@ def run_cutlass_moe_fp8(
)
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,
@@ -661,7 +661,7 @@ def run_cutlass_moe_fp4(
return
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
"""CUTLASS FP4 fused MoE expert implementation."""
@property
@@ -928,7 +928,7 @@ def run_cutlass_moe_w4a8_fp8(
)
class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
def __init__(
self,
out_dtype: torch.dtype | None,
@@ -1170,8 +1170,8 @@ def cutlass_moe_w4a8_fp8(
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
fn = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoDPEPModular(),
CutlassExpertsW4A8Fp8(
out_dtype=a.dtype,
a_strides1=a_strides1,
@@ -1186,10 +1186,9 @@ def cutlass_moe_w4a8_fp8(
quant_config=quant_config,
group_size=group_size,
),
inplace=False,
)
return fn(
return fn.apply(
a,
w1_q,
w2_q,

View File

@@ -113,7 +113,7 @@ def _valid_deep_gemm(
return True
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
class DeepGemmExperts(mk.FusedMoEExpertsModular):
"""DeepGemm-based fused MoE expert implementation."""
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):

View File

@@ -25,7 +25,7 @@ from vllm.v1.worker.ubatching import (
)
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""
Prepare/Finalize using DeepEP High-Throughput kernels.
"""
@@ -123,7 +123,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
is_token_in_rank,
event,
) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids,
topk_idx=rank_topk_ids.long(),
num_experts=num_experts,
previous_event=previous_event,
async_finish=False,
@@ -148,7 +148,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=dispatch_expert_num_tokens,
topk_idx=rank_topk_ids,
topk_idx=rank_topk_ids.long(),
topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert
# to this value.
@@ -169,7 +169,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
event,
has_scales,
token_data,
expert_topk_ids,
expert_topk_ids.int(),
num_experts,
expert_num_tokens_per_expert_list,
expert_topk_weights,
@@ -239,6 +239,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape,
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
)
return (

View File

@@ -49,7 +49,7 @@ def dequant_fp8(
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""
Prepare/Finalize using DeepEP low-latency kernels.
"""
@@ -119,7 +119,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# time. This setting is handled by post_init_setup.
self.use_ue8m0_dispatch = False
def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute):
def post_init_setup(self, fused_experts: mk.FusedMoEExperts):
if not fused_experts.supports_packed_ue8m0_act_scales():
# Early exit.
return
@@ -297,12 +297,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
a1,
dispatch_topk_ids,
dispatch_topk_ids.long(),
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
# round_scale=self.use_ue8m0_dispatch,
# use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
@@ -398,7 +398,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dbo_maybe_run_recv_hook()
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
combine_topk_ids,
combine_topk_ids.long(),
combine_topk_weights,
handle,
async_finish=False,

View File

@@ -0,0 +1,335 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
"""
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
"""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(moe_config, quant_config)
if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor:
raise NotImplementedError(
"EP parallelism is not supported with TRTLLM"
"per-tensor FP8 quantization."
)
self.routing_method_type = moe_config.routing_method
self.topk = moe_config.experts_per_token
self.intermediate_size_per_partition = (
moe_config.intermediate_size_per_partition
)
self.hidden_dim = moe_config.hidden_dim
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
# Make additional scales for per-tensor interface.
if self.quant_config.is_per_tensor:
w1_scale = self.quant_config.w1_scale
assert w1_scale is not None
a1_scale = self.quant_config.a1_scale
assert a1_scale is not None
w2_scale = self.quant_config.w2_scale
assert w2_scale is not None
a2_scale = self.quant_config.a2_scale
assert a2_scale is not None
self._g1_alphas = (w1_scale * a1_scale).squeeze()
self._g2_alphas = (w2_scale * a2_scale).squeeze()
self._g1_scale_c = (
self._g1_alphas / self.quant_config.a2_scale
if moe_config.is_act_and_mul
else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
)
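# Presumably, _g1_alphas dequantizes the gemm1 output (weight scale times
# activation scale), while _g1_scale_c additionally folds in 1/a2_scale to
# requantize the activation that feeds gemm2.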
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# TODO: also check that the FlashInfer TRTLLM kernels are available.
return p.is_cuda() and p.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
return True
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor and Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports only SiLU and RELU^2 non-gated activation."""
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
# NOTE(dbari): Default is not implemented and should not be enabled until it is
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
else:
raise ValueError("Unsupported quantization scheme.")
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Monolithic kernel so only use with naive DP/EP and TP."""
return (
not moe_parallel_config.use_all2all_kernels
or moe_parallel_config.use_naive_all2all_kernels
) and not moe_parallel_config.enable_eplb
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return True
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
def _apply_per_block(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
# Delay import for non-CUDA.
import flashinfer
assert not apply_router_weight_on_input
assert activation == MoEActivation.SILU
if e_score_correction_bias is not None:
e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype)
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)
assert self.topk <= global_num_experts
assert self.topk <= 10
assert global_num_experts % 4 == 0
assert self.quant_config.block_shape == [128, 128]
# The routing kernel expects #experts <= #threads (512).
assert global_num_experts <= 512
# Kernel requires transposed hidden state scales
# TODO: fuse into the quant kernel.
assert a1q_scale is not None
a1q_scale_t = a1q_scale.t().contiguous()
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states,
hidden_states_scale=a1q_scale_t,
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale,
gemm2_weights=w2,
gemm2_weights_scale=self.quant_config.w2_scale,
num_experts=global_num_experts,
top_k=self.topk,
n_group=(num_expert_group or 0),
topk_group=(topk_group or 0),
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type,
use_shuffled_weight=False,
)
def _apply_per_tensor(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
# Delay import for non-CUDA.
import flashinfer
from flashinfer.fused_moe.core import ActivationType
# Confirm supported activation function.
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
activation_type = ActivationType(activation_to_flashinfer_int(activation))
# Confirm Llama-4 routing is proper.
if self.routing_method_type == RoutingMethodType.Llama4:
assert apply_router_weight_on_input
else:
assert not apply_router_weight_on_input
# The DeepSeekV3 routing method requires float32 router logits.
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)
out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states,
gemm1_weights=w1,
output1_scales_scalar=self._g1_scale_c,
output1_scales_gate_scalar=self._g1_alphas,
gemm2_weights=w2,
output2_scales_scalar=self._g2_alphas,
num_experts=global_num_experts,
top_k=self.topk,
n_group=num_expert_group or 0,
topk_group=topk_group or 0,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=self.routing_method_type,
activation_type=activation_type,
)
return out
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
if self.quant_config.block_shape is not None:
return self._apply_per_block(
hidden_states,
w1,
w2,
router_logits,
activation,
global_num_experts,
expert_map,
a1q_scale,
apply_router_weight_on_input,
num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
topk_group=topk_group,
)
elif self.quant_config.is_per_tensor:
return self._apply_per_tensor(
hidden_states,
w1,
w2,
router_logits,
activation,
global_num_experts,
expert_map,
a1q_scale,
apply_router_weight_on_input,
num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
else:
raise NotImplementedError(
"Only per-block and per-tensor quantization are supported in "
f"{self.__class__.__name__}."
)

View File

@@ -0,0 +1,326 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import flashinfer
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
class TrtLlmNvFp4ExpertsBase:
"""
NvFp4 TRTLLM-Gen MoE kernels. Supports modular and monolithic interface.
"""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
self.moe_config = moe_config
self.quant_config = quant_config
self.routing_method_type = self.moe_config.routing_method
self.topk = moe_config.experts_per_token
self.intermediate_size_per_partition = (
moe_config.intermediate_size_per_partition
)
self.hidden_dim = moe_config.hidden_dim
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
assert self.quant_config.g1_alphas is not None
assert self.quant_config.a2_gscale is not None
if moe_config.is_act_and_mul:
# g1_alphas = a13_scale * w13_scale_2
# a2_gscale = (1 / a2_scale)
# g1_scale_c = a13_scale * w13_scale_2 / a2_scale
self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
else:
self.g1_scale_c = (
torch.ones_like(self.quant_config.a1_gscale)
* self.quant_config.a2_gscale
)
@staticmethod
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
"""Supports non-gated MoE (i.e. Nemotron-Nano)."""
return True
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Nvfp4 quantization."""
SUPPORTED_W_A = [
(kNvfp4Static, kNvfp4Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports only SiLU and RELU^2 non-gated activation."""
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod
def _supports_shape(hidden_dim: int) -> bool:
"""Requires hidden dim to be multiple of 512."""
return hidden_dim % 512 == 0
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular):
"""
Modular version of the implementation (just the experts).
"""
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""The modular implementation supports all parallel configs."""
return True
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,)
workspace2 = (0,)
# Hidden states are Nvfp4, packed into int8 dtype, so we
# need to multiply K by 2 to get the output shape right.
assert self.hidden_dim == K * 2
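# For example, with hidden_dim = 4096 the packed nvfp4 activations carry
# K = 2048 int8 elements per token, so the unpacked output is (M, 4096).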
output = (M, self.hidden_dim)
return (workspace1, workspace2, output)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert a1q_scale is not None
assert self.quant_config.w1_scale is not None
assert self.quant_config.w2_scale is not None
# Pack topk ids and weights into format expected by the kernel.
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
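# Packing assumed by the kernel, per the line above: the upper 16 bits of
# each int32 slot carry the expert id and the lower 16 bits carry the raw
# bf16 bit pattern of the corresponding routing weight.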
# trtllm_fp4_block_scale_routed_moe does not support autotuning
# so skip this kernel during dummy run for autotuning.
import vllm.utils.flashinfer as fi_utils
if fi_utils._is_fi_autotuning:
return hidden_states
# Invoke kernel.
flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
topk_ids=packed_tensor,
routing_bias=None,
hidden_states=hidden_states,
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
),
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=w2,
gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=self.g1_scale_c,
output1_scale_gate_scalar=self.quant_config.g1_alphas,
output2_scale_scalar=self.quant_config.g2_alphas,
num_experts=global_num_experts,
top_k=self.topk,
n_group=0,
topk_group=0,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1,
do_finalize=True,
activation_type=activation_to_flashinfer_int(activation),
output=output,
)
class TrtLlmNvFp4ExpertsMonolithic(
TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsMonolithic
):
"""
Monolithic version of the kernel (router + experts).
"""
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""The modular implementation should be used for the Dp/Ep or EPLB case."""
return (
not moe_parallel_config.use_all2all_kernels
and not moe_parallel_config.enable_eplb
)
@staticmethod
def _supports_routing_method(
routing_method_type: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# NOTE(rob): this is a conservative list.
return routing_method_type in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.Llama4,
]
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return True
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert a1q_scale is not None
assert self.quant_config.w1_scale is not None
assert self.quant_config.w2_scale is not None
assert (
apply_router_weight_on_input
and self.routing_method_type == RoutingMethodType.Llama4
) or (
not apply_router_weight_on_input
and self.routing_method_type != RoutingMethodType.Llama4
)
# Prepare routing bias into kernel format.
routing_bias = e_score_correction_bias
if routing_bias is not None:
routing_bias = routing_bias.to(torch.bfloat16)
router_logits = (
router_logits.to(torch.float32)
if self.routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
)
# Invoke kernel.
return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
),
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=w2,
gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=self.g1_scale_c,
output1_scale_gate_scalar=self.quant_config.g1_alphas,
output2_scale_scalar=self.quant_config.g2_alphas,
num_experts=global_num_experts,
top_k=self.topk,
n_group=(num_expert_group or 0),
topk_group=(topk_group or 0),
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type,
do_finalize=True,
)[0]

View File

@@ -11,13 +11,13 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
class FallbackExperts(mk.FusedMoEExpertsModular, ABC):
"""Base class for runtime dispatching of expert implementations."""
def __init__(
self,
experts: mk.FusedMoEPermuteExpertsUnpermute,
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
experts: mk.FusedMoEExpertsModular,
fallback_experts: mk.FusedMoEExpertsModular,
):
super().__init__(
moe_config=experts.moe_config, quant_config=experts.quant_config
@@ -27,8 +27,8 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEExpertsModular],
]:
"""
Get the cls for the experts and fallback experts.
@@ -149,7 +149,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
raise NotImplementedError
def apply(

View File

@@ -18,7 +18,7 @@ def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""Base class for FlashInfer MoE prepare and finalize operations."""
def __init__(
@@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch(
ep_size,
)
# Swizzle after the A2A if nvfp4.
if quant_config.quant_dtype == "nvfp4":
# Swizzle after the A2A if the MoE kernel expects swizzled scales.
if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
if x_sf.element_size() == 1:
x_sf = x_sf.view(torch.uint8)
x_sf = nvfp4_block_scale_interleave(x_sf)

View File

@@ -30,7 +30,7 @@ from vllm.utils.flashinfer import (
logger = init_logger(__name__)
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,

View File

@@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe(
return True
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
class FlashInferExperts(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: mk.FusedMoEConfig,

View File

@@ -10,16 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@@ -39,49 +29,10 @@ def _supports_no_act_and_mul() -> bool:
return True
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor and Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
def _supports_routing_method(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
routing_method: RoutingMethodType,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
# NOTE(dbari): Default is not implemented and should not be enabled until it is
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
else:
raise ValueError("Unsupported quantization scheme.")
def _supports_routing_method_bf16(
routing_method: RoutingMethodType,
) -> bool:
@@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
return not moe_parallel_config.enable_eplb
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return True
def is_supported_config_trtllm_fp8(
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason(f"current device {current_platform.device_name}")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_quant_scheme(weight_key, activation_key):
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
elif not _supports_routing_method(
weight_key, activation_key, moe_config.routing_method
):
return False, _make_reason(f"routing method {moe_config.routing_method}")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason(f"activation format {activation_format}")
elif not _supports_router_logits_dtype(
moe_config.router_logits_dtype, moe_config.routing_method
):
return False, _make_reason(
"float32 router_logits with non-DeepSeekV3 routing "
f"{moe_config.router_logits_dtype}x{moe_config.routing_method}"
)
return True, None
def is_supported_config_trtllm_bf16(
moe_config: FusedMoEConfig,
activation_format: mk.FusedMoEActivationFormat,
@@ -183,199 +78,6 @@ def is_supported_config_trtllm_bf16(
return True, None
def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
w2_weight: torch.Tensor,
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routing_method_type: int,
routed_scaling: float | None = 1.0,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
num_expert_group = num_expert_group if num_expert_group is not None else 0
topk_group = topk_group if topk_group is not None else 0
assert top_k <= global_num_experts
assert top_k <= 10
assert global_num_experts % 4 == 0
assert block_shape == [128, 128]
# Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 512
# The DeepSeekV3 routing method requires float32 router logits.
if routing_method_type == RoutingMethodType.DeepSeekV3:
routing_logits = routing_logits.to(torch.float32)
if routing_bias is not None:
routing_bias = routing_bias.to(x.dtype)
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
return flashinfer_trtllm_fp8_block_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
gemm1_weights=w13_weight,
gemm1_weights_scale=w13_weight_scale_inv,
gemm2_weights=w2_weight,
gemm2_weights_scale=w2_weight_scale_inv,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling,
routing_method_type=routing_method_type,
use_shuffled_weight=False,
)
def flashinfer_fused_moe_blockscale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
x: torch.Tensor,
w13_weight: torch.Tensor,
w13_weight_scale_inv: torch.Tensor,
w2_weight: torch.Tensor,
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routing_method_type: int,
routed_scaling: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(x)
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op(
op_name="flashinfer_fused_moe_blockscale_fp8",
op_func=flashinfer_fused_moe_blockscale_fp8,
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
def fi_trtllm_fp8_per_tensor_moe(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
activation_type: int,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
num_expert_group = num_expert_group if num_expert_group is not None else 0
topk_group = topk_group if topk_group is not None else 0
quant_hidden_states, _ = moe_kernel_quantize_input(
hidden_states,
input_scale,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False,
)
from flashinfer.fused_moe.core import ActivationType
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe
# The DeepSeekV3 routing method requires float32 router logits.
if routing_method_type == RoutingMethodType.DeepSeekV3:
routing_logits = routing_logits.to(torch.float32)
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=quant_hidden_states,
gemm1_weights=gemm1_weights,
output1_scales_scalar=output1_scales_scalar,
output1_scales_gate_scalar=output1_scales_gate_scalar,
gemm2_weights=gemm2_weights,
output2_scales_scalar=output2_scales_scalar,
num_experts=num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
routing_method_type=routing_method_type,
# TODO: enum type Required for flashinfer==0.6.3, remove with update
# https://github.com/flashinfer-ai/flashinfer/pull/2508
activation_type=ActivationType(activation_type),
)
def fi_trtllm_fp8_per_tensor_moe_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
activation_type: int,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op(
op_name="fi_trtllm_fp8_per_tensor_moe",
op_func=fi_trtllm_fp8_per_tensor_moe,
mutates_args=["hidden_states"],
fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
def flashinfer_fused_moe_bf16(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,

View File

@@ -489,11 +489,11 @@ def invoke_moe_batched_triton_kernel(
)
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""
A reference prepare/finalize class that reorganizes the tokens into
expert batched format, i.e. E x max_num_tokens x K. This is the format
that the PPLX dispatch/combine kernels use.
that the batched dispatch/combine kernels use.
"""
def __init__(
@@ -645,10 +645,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
"""
A reference MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
i.e. E x max_num_tokens x K. This is the format that the batched
dispatch/combine kernels use.
"""
@@ -877,10 +877,10 @@ def batched_moe_kernel_quantize_input(
return A_q, A_q_scale
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class BatchedTritonExperts(mk.FusedMoEExpertsModular):
"""
A Triton based MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
i.e. E x max_num_tokens x K. This is the format that the batched
dispatch/combine kernels use.
"""

View File

@@ -526,7 +526,7 @@ def batched_fused_marlin_moe(
return output
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
class MarlinExpertsBase(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,

View File

@@ -53,7 +53,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
import vllm._custom_ops as ops
import ixformer.inference.functions as ixfops
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed import get_ep_group
logger = init_logger(__name__)
@@ -575,56 +578,6 @@ def fused_moe_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
B_zp: torch.Tensor | None,
topk_weights: torch.Tensor | None,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: dict[str, Any],
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
) -> None:
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
ops.invoke_fused_moe_kernel(
A,
B,
C,
A_scale,
B_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight,
top_k,
config,
compute_type,
use_fp8_w8a8,
use_int8_w8a16,
block_shape,
)
# ops.invoke_fused_moe_kernel(A,B,C,A_scale,B_scale,topk_weights,topk_ids,sorted_token_ids,expert_ids,num_tokens_post_padded,mul_routed_weight,top_k,config,compute_type,use_fp8_w8a8,use_int8_w8a16,block_shape,B_bias)
return
# NOTE(zyongye): we can remove all the wna16 kernel
# once we drop off sm75 support
def invoke_fused_moe_wna16_cuda_kernel(
@@ -782,6 +735,7 @@ def invoke_fused_moe_triton_kernel(
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
topk_weights: torch.Tensor | None,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
@@ -799,7 +753,9 @@ def invoke_fused_moe_triton_kernel(
):
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
assert sorted_token_ids.stride(0) == 1
ops.invoke_fused_moe_kernel(
A,
B,
C,
A_scale,
B_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight,
top_k,
config,
compute_type,
use_fp8_w8a8,
use_int8_w8a16,
block_shape,
B_bias,
)
return
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
@@ -910,32 +866,6 @@ def dispatch_fused_moe_kernel(
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
) -> None:
invoke_fused_moe_kernel(
A,
B,
C,
A_scale,
B_scale,
B_zp,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight,
top_k,
config,
compute_type,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
per_channel_quant,
block_shape,
B_bias
)
return
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
@@ -999,6 +929,7 @@ def dispatch_fused_moe_kernel(
A_scale,
B_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
@@ -1397,14 +1328,13 @@ def get_default_config(
"num_warps": num_warps,
"num_stages": num_stages,
}
# TODO
numel = M * topk
if numel <= 64:
config["BLOCK_SIZE_M"] = 32
config['BLOCK_SIZE_M'] = 32
elif numel <= 1024:
config["BLOCK_SIZE_M"] = 64
config['BLOCK_SIZE_M'] = 64
else:
config["BLOCK_SIZE_M"] = 256
config['BLOCK_SIZE_M'] = 256
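# For example, M = 16 tokens with topk = 8 gives numel = 128, which falls in
# the <= 1024 bucket and selects BLOCK_SIZE_M = 64.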
return config
@@ -1424,14 +1354,12 @@ def try_get_optimal_moe_config(
else:
# First try to load optimal config from the file
E, _, N = w2_shape
if dtype == "int4_w4a16":
N = N * 2
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
configs = get_moe_configs(E, N, dtype, block_n, block_k)
# block_n = block_shape[0] if block_shape else 0
# block_k = block_shape[1] if block_shape else 0
# configs = get_moe_configs(E, N, dtype, block_n, block_k)
configs = None
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
@@ -1560,13 +1488,12 @@ def outplace_fused_experts(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
return fused_experts_impl(
return fused_experts_impl_opt(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
False,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
@@ -1626,14 +1553,12 @@ direct_register_custom_op(
def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
# torch.ops.vllm.inplace_fused_experts(**kwargs)
inplace_fused_experts(**kwargs)
hidden_states = kwargs["hidden_states"]
hidden_states = kwargs['hidden_states']
return hidden_states
def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
# return torch.ops.vllm.outplace_fused_experts(**kwargs)
return outplace_fused_experts(**kwargs)
@@ -1661,7 +1586,6 @@ def fused_experts(
"""Run fused MoE expert computation using Triton kernels."""
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
assert not inplace or not disable_inplace()
return dispatch_fused_experts_func(inplace)(
@@ -1691,6 +1615,245 @@ def fused_experts(
w2_bias=quant_config.w2_bias,
)
def fused_experts_impl_opt(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
ocp_mx_scheme: str | None = None,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
output: torch.Tensor | None = None
) -> torch.Tensor:
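"""
Unquantized MoE forward backed by ixformer grouped GEMM/GEMV kernels.

Rough flow, as implemented below: compute per-expert token indices
(optionally restricted to this rank's experts under EP), expand and
reorder hidden states into expert order, run grouped GEMM 1, apply the
gated activation, run grouped GEMM 2 (scattering back to token order),
then weight by the routing probabilities and reduce over the top-k
experts. A GEMV variant is used on the decode-only path, and an optional
preallocated `output` buffer is written in place on the EP prefill path.
"""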
# check constraints
if (use_fp8_w8a8 or use_int8_w8a8 or use_int8_w8a16 or use_int4_w4a16
or w1_scale is not None or w2_scale is not None or w1_zp is not None
or w2_zp is not None or a1_scale is not None or a2_scale is not None):
raise ValueError("Quantized MoE is not supported")
attn_metadata = get_forward_context().attn_metadata
use_ep = expert_map is not None
# The decode-only fast path below does not support expert parallelism yet.
if attn_metadata:
only_decode = not use_ep and all(
t.num_decodes > 0 and t.num_prefills == 0 for t in attn_metadata.values()
)
else:
only_decode = False
assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
num_tokens = hidden_states.size(0)
num_experts = w1.size(0)
top_k = topk_weights.size(1)
if use_ep:
local_num_experts = w1.size(0)
start_eid = get_ep_group().device_group.rank() * local_num_experts
end_eid = min((get_ep_group().device_group.rank() + 1) * local_num_experts, global_num_experts)
hidden_size = hidden_states.shape[1]
(
src_to_dst,
sorted_token_ids,
expert_sizes_gpu,
expert_sizes_cpu,
expand_tokens,
) = ixfops.moe_compute_token_index_ep(
topk_ids=topk_ids,
num_experts=global_num_experts,
start_expert_id=start_eid,
end_expert_id=end_eid,
)
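# If no tokens were routed to this rank's local experts, return an
# all-zero contribution of the expected shape.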
if expert_sizes_cpu.sum() == 0:
return torch.zeros(
(num_tokens, hidden_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
else:
expand_tokens = num_tokens * top_k
(
src_to_dst,
sorted_token_ids,
expert_sizes_gpu,
expert_sizes_cpu,
) = ixfops.moe_compute_token_index(
topk_ids=topk_ids,
num_experts=num_experts,
)
if only_decode:
# expand + reorder
hidden_states = ixfops.moe_expand_input(
hidden_states=hidden_states,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
topk=top_k,
src_to_dst=src_to_dst,
)
# group gemm 1
pt_output_1 = ixfops.moe_w16a16_group_gemv(
input=hidden_states,
weight=w1,
output_dtype=hidden_states.dtype,
tokens_per_experts_gpu=expert_sizes_gpu,
dst_to_src=None,
bias=w1_bias,
format="TN",
)
# act
if activation == "silu":
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
elif activation == "gelu":
pt_output_2 = ixfops.gelu_and_mul(pt_output_1)
elif activation == "swigluoai":
pt_output_2 = ixfops.swigluoai_and_mul(pt_output_1)
elif activation == "swiglustep":
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
output_dim = pt_output_1.shape[1]
pt_output_2 = torch.empty(
(num_tokens * top_k, output_dim // 2),
device=pt_output_1.device,
dtype=pt_output_1.dtype,
)
swiglustep_and_mul_triton(pt_output_2, pt_output_1)
else:
raise ValueError(f"Unsupported activation: {activation}")
# group gemm 2 + reorder
pt_output_3 = ixfops.moe_w16a16_group_gemv(
input=pt_output_2,
weight=w2,
output_dtype=hidden_states.dtype,
tokens_per_experts_gpu=expert_sizes_gpu,
dst_to_src=sorted_token_ids,
bias=w2_bias,
format="TN",
)
# mul + reduce_sum
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, top_k, -1),
topk_weight=topk_weights,
)
else:
expert_sizes_cpu = expert_sizes_gpu.cpu()
# expand + reorder
hidden_states = ixfops.moe_expand_input(
hidden_states=hidden_states,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
topk=top_k,
src_to_dst=src_to_dst,
)
# group gemm 1
pt_output_1 = ixfops.moe_w16a16_group_gemm(
input=hidden_states,
weight=w1,
output_dtype=hidden_states.dtype,
tokens_per_experts=expert_sizes_cpu,
dst_to_src=None,
bias=w1_bias,
format="TN",
)
# act
if activation == "silu":
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
elif activation == "gelu":
pt_output_2 = ixfops.gelu_and_mul(pt_output_1)
elif activation == "swigluoai":
pt_output_2 = ixfops.swigluoai_and_mul(pt_output_1)
elif activation == "swiglustep":
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
output_dim = pt_output_1.shape[1]
pt_output_2 = torch.empty(
(num_tokens * top_k, output_dim // 2),
device=pt_output_1.device,
dtype=pt_output_1.dtype,
)
swiglustep_and_mul_triton(pt_output_2, pt_output_1)
else:
raise ValueError(f"Unsupported activation: {activation}")
if use_ep:
pt_output_3 = torch.empty(
(num_tokens * top_k, hidden_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# group gemm 2 + reorder
pt_output_3 = ixfops.moe_w16a16_group_gemm(
input=pt_output_2,
weight=w2,
output_dtype=hidden_states.dtype,
tokens_per_experts=expert_sizes_cpu,
dst_to_src=sorted_token_ids,
format="TN",
bias=w2_bias,
output=pt_output_3,
)
# mul + reduce_sum
reduce_mask = src_to_dst == -1
if output is not None:
ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, top_k, -1),
topk_weight=topk_weights,
output=output,
mask=reduce_mask,
)
else:
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, top_k, -1),
topk_weight=topk_weights,
mask=reduce_mask,
)
else:
# group gemm 2 + reorder
pt_output_3 = ixfops.moe_w16a16_group_gemm(
input=pt_output_2,
weight=w2,
output_dtype=hidden_states.dtype,
tokens_per_experts=expert_sizes_cpu,
dst_to_src=sorted_token_ids,
bias=w2_bias,
format="TN",
)
# mul + reduce_sum
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, top_k, -1),
topk_weight=topk_weights,
)
if output is None:
return final_hidden_states
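For readers not familiar with the ixfops kernels used above, the snippet below is a plain-PyTorch reference for the same dataflow (gather the tokens routed to each expert, "group gemm 1", silu_and_mul, "group gemm 2", weighted reduce). It is illustrative only: `reference_moe` and its shapes are assumptions, it ignores EP, quantization and the decode-only path, and it is not part of this change.

```python
import torch
import torch.nn.functional as F


def reference_moe(hidden_states, w1, w2, topk_weights, topk_ids):
    # hidden_states: (M, K); w1: (E, N, K); w2: (E, K, N // 2)
    out = torch.zeros_like(hidden_states)
    for e in range(w1.shape[0]):
        # Tokens routed to expert e and the top-k slot they occupy.
        token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]            # expand + reorder
        h = x @ w1[e].t()                       # group gemm 1
        d = h.shape[-1] // 2
        h = F.silu(h[..., :d]) * h[..., d:]     # silu_and_mul
        y = h @ w2[e].t()                       # group gemm 2
        w = topk_weights[token_idx, slot_idx].unsqueeze(-1)
        out.index_add_(0, token_idx, (y * w).to(out.dtype))  # mul + reduce_sum
    return out
```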
def _get_config_quant_dtype(
use_fp8_w8a8: bool,
@@ -1825,7 +1988,7 @@ def fused_experts_impl(
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
activation_out_dim = mk.FusedMoEExpertsModular.adjust_N_for_activation(
N, activation_enum
)
intermediate_cache2 = torch.empty(
@@ -1910,28 +2073,28 @@ def fused_experts_impl(
ocp_mx_scheme=ocp_mx_scheme,
)
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
# activates only a small fraction of total experts
SPARSITY_FACTOR = 4
# block quantized code path is not implemented yet.
naive_block_assignment = (
expert_map is None
and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
and not (
(use_int8_w8a16 or use_int4_w4a16)
and block_shape is not None
and block_shape[1] > 0
)
)
# # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
# # activates only a small fraction of total experts
# SPARSITY_FACTOR = 4
# # block quantized code path is not implemented yet.
# naive_block_assignment = (
# expert_map is None
# and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
# and not (
# (use_int8_w8a16 or use_int4_w4a16)
# and block_shape is not None
# and block_shape[1] > 0
# )
# )
# if not naive_block_assignment:
# sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
# curr_topk_ids,
# config["BLOCK_SIZE_M"],
# global_num_experts,
# expert_map,
# ignore_invalid_experts=True,
# )
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids,
config["BLOCK_SIZE_M"],
global_num_experts,
expert_map,
ignore_invalid_experts=True,
)
# else:
# max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
# expert_ids = curr_topk_ids.view(-1)
@@ -1941,14 +2104,6 @@ def fused_experts_impl(
# num_tokens_post_padded.fill_(max_num_tokens_padded)
# sorted_token_ids = None
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids,
config["BLOCK_SIZE_M"],
global_num_experts,
expert_map,
ignore_invalid_experts=True,
)
dispatch_fused_moe_kernel(
qcurr_hidden_states,
w1,
@@ -2015,20 +2170,14 @@ def fused_experts_impl(
B_bias=w2_bias,
)
# ops.moe_sum(
# intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx],
# )
torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class TritonExperts(mk.FusedMoEExpertsModular):
"""Triton-based fused MoE expert implementation."""
def __init__(
@@ -2091,8 +2240,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# return not moe_parallel_config.use_fi_all2allv_kernels
return True
return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool:
return True
@@ -2138,157 +2286,31 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
# Check constraints.
if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch"
else:
assert hidden_states.size(-1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
)
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert hidden_states.dim() == 2
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32,
torch.float16,
torch.bfloat16,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
]
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids
)
if global_num_experts == -1:
global_num_experts = E
config = try_get_optimal_moe_config(
w1.size(),
w2.size(),
top_k_num,
self.quant_config.config_name(hidden_states.dtype),
num_tokens,
block_shape=self.block_shape,
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif (
hidden_states.dtype == torch.float8_e4m3fn
or hidden_states.dtype == torch.float8_e4m3fnuz
):
compute_type = tl.bfloat16
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
# Note that the output tensor might be in workspace1
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
cache2_dim = self.adjust_N_for_activation(N, activation)
intermediate_cache2 = _resize_cache(
workspace13, (num_tokens * top_k_num, cache2_dim)
)
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
)
invoke_fused_moe_triton_kernel(
hidden_states,
w1,
intermediate_cache1,
a1q_scale,
self.w1_scale,
None, # topk_weights
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False, # mul_routed_weights
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=self.w1_bias,
)
self.activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
)
a2q_scale: torch.Tensor | None = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2,
a2_scale,
self.quant_dtype,
self.per_act_token_quant,
self.block_shape,
)
# invoke_fused_moe_triton_kernel(
# qintermediate_cache2,
# w2,
# intermediate_cache3,
# a2q_scale,
# self.w2_scale,
# topk_weights,
# sorted_token_ids,
# expert_ids,
# num_tokens_post_padded,
# not apply_router_weight_on_input,
# 1,
# config,
# compute_type=compute_type,
# use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
# use_int8_w8a8=self.quant_config.use_int8_w8a8,
# use_int8_w8a16=self.quant_config.use_int8_w8a16,
# use_int4_w4a16=self.quant_config.use_int4_w4a16,
# per_channel_quant=self.per_act_token_quant,
# block_shape=self.block_shape,
# B_bias=self.w2_bias,
# )
invoke_fused_moe_kernel(
qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
self.w2_scale,
self.w2_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=self.w2_bias,
)
# separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
ops.moe_sum(input, output)
fused_experts_impl_opt(hidden_states,
w1,
w2,
topk_weights,
topk_ids,
activation,
apply_router_weight_on_input,
self.quant_config.use_fp8_w8a8,
self.quant_config.use_int8_w8a8,
self.quant_config.use_int8_w8a16,
self.quant_config.use_int4_w4a16,
self.quant_config.ocp_mx_scheme,
self.quant_config.per_act_token_quant,
global_num_experts,
expert_map,
self.quant_config.w1_scale,
self.quant_config.w2_scale,
self.quant_config.w1_zp,
self.quant_config.w2_zp,
self.quant_config.a1_scale,
self.quant_config.a2_scale,
self.quant_config.block_shape,
self.quant_config.w1_bias,
self.quant_config.w2_bias,
output)
class TritonWNA16Experts(TritonExperts):

View File

@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoEExpertsModular,
FusedMoEPrepareAndFinalizeModular,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
@@ -27,19 +27,21 @@ class FusedMoEMethodBase(QuantizeMethodBase):
super().__init__()
self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
self.moe_mk: mk.FusedMoEModularKernel | None = None
self.moe_kernel: mk.FusedMoEKernel | None = None
@property
def supports_internal_mk(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return self.moe_mk is not None
return self.moe_kernel is not None
@property
def mk_owns_shared_expert(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return self.moe_mk is not None and self.moe_mk.shared_experts is not None
return (
self.moe_kernel is not None and self.moe_kernel.shared_experts is not None
)
@abstractmethod
def create_weights(
@@ -66,35 +68,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
) -> FusedMoEPrepareAndFinalizeModular | None:
from .all2all_utils import maybe_make_prepare_finalize
return maybe_make_prepare_finalize(
pf = maybe_make_prepare_finalize(
self.moe, self.moe_quant_config, routing_tables
)
assert pf is None or isinstance(pf, FusedMoEPrepareAndFinalizeModular)
return pf
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
) -> FusedMoEExpertsModular:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise NotImplementedError(
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize"
)
def prepare_dp_allgather_tensor(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
raise NotImplementedError(
"Method 'prepare_dp_allgather_tensor' is not implemented in "
f"{self.__class__.__name__}."
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
@abstractmethod
@@ -105,8 +97,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property
def topk_indices_dtype(self) -> torch.dtype | None:
if self.moe_mk is not None:
return self.moe_mk.prepare_finalize.topk_indices_dtype()
if self.moe_kernel is not None:
return self.moe_kernel.prepare_finalize.topk_indices_dtype()
return None
@property
@@ -119,7 +111,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property
def is_monolithic(self) -> bool:
return False
if self.moe_kernel is None:
if hasattr(self, "experts_cls"):
return self.experts_cls.is_monolithic()
else:
return False
return self.moe_kernel.is_monolithic
def apply(
self,

View File

@@ -13,8 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel,
FusedMoEPrepareAndFinalize,
FusedMoEKernel,
FusedMoEPrepareAndFinalizeModular,
)
logger = init_logger(__name__)
@@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
# --8<-- [end:modular_fused_moe]
def __init__(
self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
self, old_quant_method: FusedMoEMethodBase, moe_kernel: FusedMoEKernel
):
super().__init__(old_quant_method.moe)
self.moe_quant_config = old_quant_method.moe_quant_config
self.moe_mk = experts
self.moe_kernel = moe_kernel
self.disable_expert_map = getattr(
old_quant_method,
"disable_expert_map",
not self.moe_mk.supports_expert_map(),
not self.moe_kernel.supports_expert_map(),
)
self.old_quant_method = old_quant_method
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
@@ -43,13 +43,13 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def make(
moe_layer: torch.nn.Module,
old_quant_method: FusedMoEMethodBase,
prepare_finalize: FusedMoEPrepareAndFinalize,
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
shared_experts: torch.nn.Module | None,
inplace: bool = False,
) -> "FusedMoEModularMethod":
return FusedMoEModularMethod(
old_quant_method,
FusedMoEModularKernel(
FusedMoEKernel(
prepare_finalize,
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
@@ -90,8 +90,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
return self.moe_mk(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,

View File

@@ -6,6 +6,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
@@ -178,7 +179,40 @@ def triton_kernel_moe_forward(
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
unpadded_N_w1=None,
unpadded_K_w1=None,
unpadded_N_w2=None,
unpadded_K_w2=None,
) -> torch.Tensor:
if (
quant_config is not None
and quant_config.use_mxfp4_w4a8
and rocm_aiter_ops.is_enabled()
):
from aiter.ops.triton.moe_routing.routing import routing as aiter_routing
routing_data, gather_idx, scatter_idx = aiter_routing(
gating_output, topk, sm_first=not renormalize
)
return triton_kernel_fused_mxfp4_w4a8_experts(
None,
hidden_states,
w1,
w2,
routing_data,
gather_idx,
scatter_idx,
activation=activation.value,
quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
unpadded_N_w1=unpadded_N_w1,
unpadded_K_w1=unpadded_K_w1,
unpadded_N_w2=unpadded_N_w2,
unpadded_K_w2=unpadded_K_w2,
)
if expert_map is not None:
# With expert parallelism, legacy_routing produces routing data
# using global expert IDs which don't correspond to local weight
@@ -210,6 +244,9 @@ def triton_kernel_moe_forward(
effective_global_num_experts = global_num_experts
output = torch.empty_like(hidden_states)
effective_quant_config = (
quant_config if quant_config is not None else FUSED_MOE_UNQUANTIZED_CONFIG
)
return triton_kernel_fused_experts(
output,
@@ -221,7 +258,7 @@ def triton_kernel_moe_forward(
scatter_idx,
topk=topk,
activation=activation,
quant_config=quant_config,
quant_config=effective_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=effective_global_num_experts,
expert_map=effective_expert_map,
@@ -252,8 +289,7 @@ def triton_kernel_fused_experts(
assert activation == MoEActivation.SWIGLUOAI, (
"Only SWIGLUOAI activation is supported"
)
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
assert quant_config is not None
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
@@ -330,6 +366,98 @@ def triton_kernel_fused_experts(
return output_tensor
# This is a triton implementation of the fused_experts function
def triton_kernel_fused_mxfp4_w4a8_experts(
output_tensor: torch.Tensor,
hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor
w2, # Tensor or triton_kernels.Tensor
routing_data, # RoutingData
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
activation: str = "silu",
quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
a1q_scale: torch.Tensor | None = None,
unpadded_N_w1=None,
unpadded_K_w1=None,
unpadded_N_w2=None,
unpadded_K_w2=None,
) -> torch.Tensor:
assert quant_config is not None
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]
E, _, N = w1.shape
if global_num_experts == -1:
global_num_experts = E
gammas = routing_data.gate_scal if routing_data else None
from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4
from aiter.ops.triton.quant_moe import downcast_to_static_fp8
assert quant_config.w1_precision is not None, (
"w1_precision in quant config can't be None"
)
assert quant_config.w2_precision is not None, (
"w2_precision in quant config can't be None"
)
hidden_states = downcast_to_static_fp8(
hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale
)
intermediate_cache1 = moe_gemm_a8w4(
hidden_states,
w1.storage.data,
None,
quant_config.w1_precision.weight_scale.storage.data,
quant_config.w1_precision.flex_ctx.lhs_data.scale,
quant_config.w2_precision.flex_ctx.lhs_data.scale,
quant_config.w1_bias,
routing_data,
gather_indx=gather_indx,
gammas=gammas if apply_router_weight_on_input else None,
swizzle_mx_scale="CDNA4_SCALE",
out_dtype=torch.float8_e4m3fn,
apply_swiglu=True,
alpha=swiglu_alpha,
limit=swiglu_limit,
unpadded_N=unpadded_N_w1,
unpadded_K=unpadded_K_w1,
)
intermediate_cache3 = moe_gemm_a8w4(
intermediate_cache1,
w2.storage.data,
None,
quant_config.w2_precision.weight_scale.storage.data,
quant_config.w2_precision.flex_ctx.lhs_data.scale,
None,
quant_config.w2_bias,
routing_data,
scatter_indx=scatter_indx,
gammas=None if apply_router_weight_on_input else gammas,
swizzle_mx_scale="CDNA4_SCALE",
unpadded_N=unpadded_N_w2,
unpadded_K=unpadded_K_w2,
)
return intermediate_cache3
def make_routing_data(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
@@ -383,7 +511,7 @@ def make_routing_data(
return routing_data, gather_indx, scatter_indx
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_current_device() -> bool:
raise NotImplementedError(
@@ -520,6 +648,9 @@ class OAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
if self.quant_config is None:
self.quant_config: FusedMoEQuantConfig = FUSED_MOE_UNQUANTIZED_CONFIG
if expert_map is not None:
topk_ids = expert_map[topk_ids]

View File

@@ -5,8 +5,8 @@ from collections.abc import Callable, Iterable
from enum import Enum
from typing import Literal, cast, get_args, overload
import ast, re
import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
@@ -54,10 +54,14 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from vllm.model_executor.layers.utils import (
parse_opt_exclude_layers,
weight_quant_l1,
weight_quant_l2,
)
logger = init_logger(__name__)
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
@@ -333,6 +337,7 @@ class FusedMoE(CustomOp):
gate: torch.nn.Module | None = None,
shared_experts: torch.nn.Module | None = None,
routed_input_transform: torch.nn.Module | None = None,
fused_shared_output: bool = False,
):
super().__init__()
@@ -483,6 +488,8 @@ class FusedMoE(CustomOp):
(expert_mask == 0) | (expert_mask == 1)
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
self.hidden_size = hidden_size
self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
@@ -526,16 +533,18 @@ class FusedMoE(CustomOp):
# Round up hidden size before creating moe_config.
# This way moe_config is created with the correct hidden_size from the start.
unpadded_hidden_size = hidden_size
self.model_type = (
self.vllm_config.model_config.hf_config.model_type
if self.vllm_config.model_config is not None
else None
)
hidden_size = maybe_roundup_hidden_size(
hidden_size=hidden_size,
act_dtype=moe_in_dtype,
moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None,
model_type=(
self.vllm_config.model_config.hf_config.model_type
if self.vllm_config.model_config is not None
else None
),
model_type=self.model_type,
is_mxfp4_quant=(
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
),
@@ -581,14 +590,27 @@ class FusedMoE(CustomOp):
"""
quant_method = None
if self.quant_config is not None:
self.opt_level = 0
quant_method = self.quant_config.get_quant_method(self, prefix)
if quant_method is None:
quant_method = UnquantizedFusedMoEMethod(self.moe_config)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
CompressedTensorsL1OptMoEMethod, CompressedTensorsL2OptMoEMethod)
if self.opt_level == 1:
quant_method = CompressedTensorsL1OptMoEMethod(self.moe_config)
elif self.opt_level == 2:
quant_method = CompressedTensorsL2OptMoEMethod(self.moe_config)
else:
quant_method = UnquantizedFusedMoEMethod(self.moe_config)
assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
self.opt_level = envs.VLLM_MOE_OPT_LEVEL
if parse_opt_exclude_layers(envs.VLLM_OPT_EXCLUDE_LAYERS, prefix):
self.opt_flag = False
logger.info(f"Excluding layer {prefix} from optimization")
self.quant_method: FusedMoEMethodBase = _get_quant_method()
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
@@ -611,6 +633,7 @@ class FusedMoE(CustomOp):
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
"unpadded_hidden_size": unpadded_hidden_size,
"intermediate_size_per_partition": self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
@@ -625,6 +648,7 @@ class FusedMoE(CustomOp):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.base_quant_method = self.quant_method
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
@@ -638,7 +662,10 @@ class FusedMoE(CustomOp):
)
and self._shared_experts is not None
)
if fused_shared_output:
assert not self.use_ep, "Fused shared output is only supported when EP is disabled."
assert shared_experts is not None, "Shared experts must be provided when fused_shared_output is True."
self.fused_shared_output = fused_shared_output
self.runner = self._init_runner()
def _init_runner(self):
@@ -655,6 +682,7 @@ class FusedMoE(CustomOp):
quant_method=self.quant_method,
reduce_results=self.reduce_results,
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
fused_shared_output=self.fused_shared_output,
)
# TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py
@@ -681,7 +709,7 @@ class FusedMoE(CustomOp):
# routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend.
routing_tables = self._maybe_init_expert_routing_tables()
prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize(
routing_tables=routing_tables
)
if prepare_finalize is not None:
@@ -691,7 +719,7 @@ class FusedMoE(CustomOp):
self._replace_quant_method(
FusedMoEModularMethod.make(
self,
self.quant_method,
self.base_quant_method,
prepare_finalize,
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
@@ -959,11 +987,7 @@ class FusedMoE(CustomOp):
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
try:
expert_data.copy_(loaded_weight)
except Exception as e:
print(expert_data.shape, expert_data.dtype, loaded_weight.shape, loaded_weight.dtype)
raise e
expert_data.copy_(loaded_weight)
def _load_w2(
self,
@@ -976,7 +1000,7 @@ class FusedMoE(CustomOp):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
shard_size = loaded_weight.shape[shard_dim] // self.tp_size
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
# and we're not loading the full weight
if not load_full and loaded_weight.ndim > 0:
@@ -984,7 +1008,55 @@ class FusedMoE(CustomOp):
shard_dim, shard_size * tp_rank, shard_size
)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
expert_data.narrow(shard_dim, 0, shard_size).copy_(loaded_weight)
def _load_model_opt_weight_or_group_weight_scale(self,
shard_dim: int,
shard_dim_scale: int,
expert_data: torch.Tensor,
scale_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
opt_level: int,
load_full_w2: bool = False):
"""
Quantize the loaded weight on the fly (opt level 1 or 2 via weight_quant_l1
or weight_quant_l2) and load the resulting weight and scale into the expert
parameters.
:param shard_dim: dimension to shard the weight
:param shard_dim_scale: dimension to shard the weight scale
:param expert_data: weight parameter for a particular expert
:param scale_data: scale parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param opt_level: optimization level (1 or 2) selecting the quantization scheme
:param load_full_w2: whether or not the w2 loaded should be sharded.
"""
assert opt_level in [1, 2]
if opt_level == 1:
weight, scale = weight_quant_l1(loaded_weight)
else:
weight, scale = weight_quant_l2(loaded_weight)
scale = scale.view(1, -1)
if shard_id == "w2":
# In the case where we have actorder/g_idx, we do not partition the
# w2 scales, as indicated by `load_full` argument, for all tp cases
self._load_w2(shard_dim=shard_dim,
loaded_weight=weight,
expert_data=expert_data,
tp_rank=tp_rank,
load_full=load_full_w2)
scale_data.copy_(scale)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=weight,
expert_data=expert_data,
tp_rank=tp_rank)
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim_scale,
loaded_weight=scale,
expert_data=scale_data,
tp_rank=tp_rank)
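The actual `weight_quant_l1` / `weight_quant_l2` helpers live in vllm.model_executor.layers.utils and are not shown in this diff. As a rough, hypothetical illustration of the contract the loader relies on (weight in, quantized weight plus scale out), a generic per-channel absmax scheme could look like the following; it is a stand-in, not the overlay's implementation.

```python
import torch


def weight_quant_per_channel(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-output-channel absmax int8 quantization: w ~= qw.float() * scale[:, None]
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    qw = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
    return qw, scale.squeeze(-1).to(torch.float32)


qw, scale = weight_quant_per_channel(torch.randn(128, 64))
print(qw.dtype, scale.shape)  # torch.int8 torch.Size([128])
```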
def _load_single_value(
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
@@ -1147,7 +1219,6 @@ class FusedMoE(CustomOp):
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = int(not shard_dim)
shard_dim_force = getattr(param, "shard_dim", None)
shard_dim = shard_dim_force if shard_dim_force is not None else shard_dim
@@ -1309,13 +1380,28 @@ class FusedMoE(CustomOp):
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank,
)
if self.opt_level != 0:
scale_name = weight_name.split('.')[-1] + "_scale"
params_dict = dict(self.named_parameters())
scale_param = params_dict[scale_name]
shard_dim_scale = getattr(scale_param, "shard_dim", None)
scale_expert_data = scale_param.data if full_load else scale_param.data[expert_id]
self._load_model_opt_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
shard_dim_scale=shard_dim_scale,
loaded_weight=loaded_weight,
expert_data=expert_data,
scale_data=scale_expert_data,
opt_level=self.opt_level,
tp_rank=self.tp_rank)
else:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=self.tp_rank)
return True if return_success else None
return False if return_success else None

View File

@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
@@ -56,25 +57,25 @@ logger = init_logger(__name__)
# MoE kernel implementations.
#
# The following main classes are defined:
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
# * FusedMoEPrepareAndFinalizeModular - an abstract base class for preparation of MoE
# inputs (e.g. quantization, distribution) and finalization of Moe outputs.
# The prepare method must take care of any needed quantization and the
# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
# finalize method, informed by the FusedMoEExpertsModular method,
# may apply weights and/or do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
# * FusedMoEExpertsModular - an abstract base class for the main fused
# MoE operation, i.e matmul + act_mul + optionally quant + matmul.
# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
# Some FusedMoEExpertsModular implementations may choose to do
# the weight application and/or reduction. The class communicates this
# to [Finalize] via a TopKWeightAndReduce object.
# * FusedMoEModularKernel - an interface class that combines a
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
# FusedMoEPrepareAndFinalizeModular and a FusedMoEExpertsModular to
# provide the standard fused MoE kernel interface.
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
# by the FusedMoEExpertsModular implementation that is passed
# on to [Finalize].
#
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
# class `FusedMoEPrepareAndFinalize` since they could use collective
# class `FusedMoEPrepareAndFinalizeModular` since they could use collective
# communication mechanisms that need to be consistent.
#
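As a reading aid for the renamed interfaces described above, here is a condensed, standalone analogue of the decomposition. The class names only mirror the comment; the toy bodies are assumptions and omit the quantization, dispatch and workspace management the real classes handle.

```python
from abc import ABC, abstractmethod

import torch


class PrepareAndFinalize(ABC):  # ~ FusedMoEPrepareAndFinalizeModular
    @abstractmethod
    def prepare(self, a1: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def finalize(self, fused_out: torch.Tensor) -> torch.Tensor: ...


class Experts(ABC):  # ~ FusedMoEExpertsModular
    @abstractmethod
    def apply(self, a1q: torch.Tensor) -> torch.Tensor: ...


class Kernel:  # ~ FusedMoEKernel
    def __init__(self, pf: PrepareAndFinalize, experts: Experts):
        self.pf, self.experts = pf, experts

    def apply(self, a1: torch.Tensor) -> torch.Tensor:
        # quantize/dispatch -> expert compute -> weight-apply/reduce
        return self.pf.finalize(self.experts.apply(self.pf.prepare(a1)))
```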
@@ -155,25 +156,96 @@ PrepareResultType = tuple[
torch.Tensor | None,
]
#
# PrepareMonolithicResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - dispatched router logits.
#
# See `FusedMoEPrepareAndFinalizeMonolithic.prepare` below.
#
PrepareMonolithicResultType = tuple[
torch.Tensor,
torch.Tensor | None,
torch.Tensor,
]
ReceiverType = Callable[[], PrepareResultType]
################################################################################
# Prepare/Finalize
################################################################################
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
described above.
There are two variants of this class:
* FusedMoEPrepareAndFinalizeModular - this operates on topk ids and weights
* FusedMoEPrepareAndFinalizeMonolithic - this operates on router_logits
"""
def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
def post_init_setup(self, fused_experts: "FusedMoEExperts"):
"""
Initialize FusedMoEPrepareAndFinalize settings that depend on
FusedMoEPermuteExpertsUnpermute experts object.
The FusedMoEPrepareAndFinalize implementations that have such
Initialize FusedMoEPrepareAndFinalizeModular settings that depend on
FusedMoEExpertsModular experts object.
The FusedMoEPrepareAndFinalizeModular implementations that have such
dependencies may choose to override this function.
"""
return
@property
@abstractmethod
def activation_format(self) -> FusedMoEActivationFormat:
"""
A property indicating the output format of the activations for the
'prepare' method.
"""
raise NotImplementedError
@abstractmethod
def topk_indices_dtype(self) -> torch.dtype | None:
"""
The PrepareFinalize All2All implementations generally constrain the
dtype of the topk_ids they support. This function returns the
required topk indices dtype so it can be respected.
Return None if there are no such restrictions.
"""
raise NotImplementedError
@abstractmethod
def max_num_tokens_per_rank(self) -> int | None:
"""
Some PrepareFinalize All2All implementations are batched, meaning
they can process only a set of tokens at a time. This
function returns the batch size, i.e. the maximum number of tokens
the implementation can process at a time.
Return None if there are no such restrictions.
"""
raise NotImplementedError
@abstractmethod
def num_dispatchers(self) -> int:
raise NotImplementedError
@abstractmethod
def output_is_reduced(self) -> bool:
"""
Indicates whether or not the output of finalize is reduced across all
ranks.
"""
raise NotImplementedError
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
"""
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
described above for the Modular case.
"""
@abstractmethod
def prepare(
self,
@@ -198,7 +270,7 @@ class FusedMoEPrepareAndFinalize(ABC):
activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
defer input quantization to the FusedMoEExpertsModular
in cases where the compute kernel expects unquantized inputs
Returns a tuple of:
@@ -245,7 +317,7 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEPermuteExpertsUnpermute
defer input quantization to the FusedMoEExpertsModular
in cases where the compute kernel expects unquantized inputs
Returns a callback or a hook callback pair that when invoked waits for
@@ -338,56 +410,58 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise NotImplementedError
@property
class FusedMoEPrepareAndFinalizeMonolithic(FusedMoEPrepareAndFinalize):
"""
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
described above for the monolithic case.
"""
@abstractmethod
def activation_format(self) -> FusedMoEActivationFormat:
def prepare(
self,
a1: torch.Tensor,
router_logits: torch.Tensor,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> PrepareMonolithicResultType:
"""
A property indicating the output format of the activations for the
'prepare' method.
Optional method for subclasses compatible with
FusedMoEExpertsMonolithic kernels.
Perform any quantization (and/or) dispatching needed for this kernel.
- a1: The (unquantized) input to the MoE layer.
- quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to
defer input quantization to the FusedMoEExpertsModular
Returns a tuple of:
- quantized + dispatched a.
- Optional quantized + dispatched a1_scales.
"""
raise NotImplementedError
@abstractmethod
def topk_indices_dtype(self) -> torch.dtype | None:
def finalize(self, fused_expert_output: torch.Tensor) -> torch.Tensor:
"""
The PrepareFinalize All2All implementations generally constrain the
dtype of the topk_ids they support. This function returns the
required topk indices dtype so it can be respected.
Return None if there are no such restrictions.
Optional method for subclasses compatible with
FusedMoEExpertsMonolithic kernels.
Perform any combine plus apply weights and perform a reduction on the
fused experts output.
- fused_expert_output: The unweighted, unreduced output of the fused
experts, it will have (M, topk, K) shape.
"""
raise NotImplementedError
@abstractmethod
def max_num_tokens_per_rank(self) -> int | None:
"""
Some PrepareFinalize All2All implementations are batched, meaning
they can process only a set of tokens at a time. This
function returns the batch size, i.e. the maximum number of tokens
the implementation can process at a time.
Return None if there are no such restrictions.
"""
raise NotImplementedError
@abstractmethod
def num_dispatchers(self) -> int:
raise NotImplementedError
@abstractmethod
def output_is_reduced(self) -> bool:
"""
Indicates whether or not the output of finalize is reduced across all
ranks.
"""
raise NotImplementedError
################################################################################
# Experts
################################################################################
# TODO: add supported activations method (return string)
class FusedMoEPermuteExpertsUnpermute(ABC):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above.
"""
class FusedMoEExperts(ABC):
def __init__(
self,
moe_config: FusedMoEConfig,
@@ -419,6 +493,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@staticmethod
def is_monolithic() -> bool:
raise NotImplementedError("Implemented by subclasses.")
@property
def expects_unquantized_inputs(self) -> bool:
"""
@@ -439,49 +517,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation
tensors is not obvious. It needs to be done this way specifically
due to subtle issues with particular kernels, e.g. the int4 kernels
divide the trailing dimension by two, so it's not "correct" to
extract N or K from the trailing dimension of w1 or w2. Similarly,
some kernels transpose the weights, so this needs to be kept in mind.
Note: This implementation covers most cases. However, if experts
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = a1.size(-1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
#
# Various helpers for registering support for various features.
# Used by the oracle to select a particular kernel for a deployment.
@@ -489,7 +524,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@staticmethod
def is_supported_config(
cls: type["FusedMoEPermuteExpertsUnpermute"],
cls: type["FusedMoEExperts"],
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
@@ -512,6 +547,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return False, _make_reason(
f"parallel config {moe_config.moe_parallel_config}"
)
elif not cls._supports_routing_method(
moe_config.routing_method, weight_key, activation_key
):
return False, _make_reason(f"routing method {moe_config.routing_method}")
elif not cls._supports_router_logits_dtype(
moe_config.router_logits_dtype,
moe_config.routing_method,
):
return False, _make_reason(
f"router logits dtype {moe_config.router_logits_dtype}"
)
elif not cls._supports_shape(moe_config.hidden_dim):
return False, _make_reason(
f"{moe_config.hidden_dim} hidden dim is not supported"
)
elif activation_format != cls.activation_format():
return False, _make_reason(f"{activation_format.value} activation format")
return True, None
@@ -554,10 +604,48 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@abstractmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""
Whether the kernel supports deployment in expert parallel.
Whether the kernel supports deployment in a particular parallel config.
Can be overridden if a kernel does not support EP, SP or some other
configuration.
"""
raise NotImplementedError
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""
Whether the kernel supports a routing method (e.g. GroupedTopK).
Can be overridden by monolithic kernels that execute the router
in addition to the experts if certain routers are not supported.
"""
return True
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
Whether a kernel supports a particular dtype for router logits input.
Can be overridden by monolithic kernels that execute the router
in addition to the experts if certain dtypes are not supported.
"""
return True
@staticmethod
def _supports_shape(hidden_dim: int) -> bool:
"""
Whether a kernel supports a particular shape. Can be overridden if a kernel
has specific shape requirements.
"""
return True
#
# Various helpers for accessing quantization parameters from the
# quant_config.
@@ -654,6 +742,65 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
return False
def enable_chunking(self):
return (
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
)
class FusedMoEExpertsModular(FusedMoEExperts):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above.
"""
@staticmethod
def is_monolithic() -> bool:
return False
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation
tensors is not obvious. It needs to be done this way specifically
due to subtle issues with particular kernels, e.g. the int4 kernels
divide the trailing dimension by two, so it's not "correct" to
extract N or K from the trailing dimension of w1 or w2. Similarly,
some kernels transpose the weights, so this needs to be kept in mind.
Note: This implementation covers most cases. However, if experts
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = a1.size(-1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
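A quick worked example of the shape extraction above, with small illustrative sizes (assumptions, not taken from any particular model):

```python
import torch

a1 = torch.randn(8, 64)                           # (M, K): 8 tokens, hidden 64
w1 = torch.randn(4, 48, 64)                       # (E, N, K): N = 2 * intermediate
w2 = torch.randn(4, 64, 24)                       # (E, K, N // 2)
topk_ids = torch.zeros(8, 2, dtype=torch.int64)   # (M, topk)

E, N, _ = w1.size()
K = a1.size(-1)
M = a1.size(0)            # a1 is 2-D here, i.e. the pre-permute case
topk = topk_ids.size(1)
print(E, M, N, K, topk)   # 4 8 48 64 2
```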
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
"""
Workspace type: The dtype to use for the workspace tensors.
@@ -726,11 +873,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
) -> None:
apply_moe_activation(activation, output, input)
def enable_chunking(self):
return (
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
)
@abstractmethod
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
raise NotImplementedError
@@ -791,6 +934,67 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
raise NotImplementedError
class FusedMoEExpertsMonolithic(FusedMoEExperts):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above, but with the monolithic interface (accepts router logits
rather than topk ids and weights).
"""
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""
Whether the kernel supports a routing method (e.g. GroupedTopK).
Monolithic kernels should explicitly opt-in to support.
"""
raise NotImplementedError
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
Whether the kernel supports a dtype for router logits.
Monolithic kernels should explicitly opt-in to support.
"""
raise NotImplementedError
@staticmethod
def is_monolithic() -> bool:
return True
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
"""
Same as FusedMoEExpertsModular.apply(), except it takes router_logits
instead of topk_ids and topk_weights. This is useful for kernels
with a fused router and fused experts (e.g. FLASHINFER_TRTLLM).
"""
raise NotImplementedError
def _slice_scales(
scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
@@ -802,75 +1006,32 @@ def _slice_scales(
return None
################################################################################
# Kernel
################################################################################
@final
class FusedMoEModularKernel(torch.nn.Module):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
a FusedMoEPermuteExpertsUnpermute to provide an interface that
is compatible with the `fused_experts` function in fused_moe.py.
It takes care of managing any required scratch space.
Note: Instances of this class should only be used for a single model
layer due to any layer specific state that may be used by the component
objects.
"""
class FusedMoEKernelModularImpl:
def __init__(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute,
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
fused_experts: FusedMoEExpertsModular,
shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
inplace: bool = False,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
self.moe_parallel_config = moe_parallel_config
self.inplace = inplace
# prefer an explicit FusedMoEParallelConfig when available (from
# FusedMoE layers / tests).
# if not provided, assume this kernel is
# running in a non-DP+EP context
self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
self.is_dp_ep = (
moe_parallel_config is not None
and moe_parallel_config.dp_size > 1
and moe_parallel_config.use_ep
)
self._post_init_setup()
assert (
prepare_finalize.activation_format == fused_experts.activation_format()
), (
f"{prepare_finalize.__class__.__name__}."
f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_format()}"
)
def _post_init_setup(self):
"""
Resolve any leftover setup dependencies between self.prepare_finalize
and self.fused_experts here.
"""
self.prepare_finalize.post_init_setup(self.fused_experts)
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps.
"""
return self.fused_experts.supports_expert_map()
def output_is_reduced(self) -> bool:
"""
Indicates whether or not the output of fused MoE kernel
is reduced across all ranks.
"""
return self.prepare_finalize.output_is_reduced()
def _chunk_info(self, M: int) -> tuple[int, int]:
"""
Compute number of chunks and chunk size for given M.
@@ -919,7 +1080,7 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
# Force worst-case allocation in profiling run for
# "mk.FusedMoEModularKernel.Standard" formats where this is only bounded
# "mk.FusedMoEKernel.Standard" formats where this is only bounded
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
# DP+EP due to the random token routing.
is_profile_run = (
@@ -1172,9 +1333,9 @@ class FusedMoEModularKernel(torch.nn.Module):
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
# kernels. CUDAGraph compatible all2all kernels like the DeepEP
# low-latency kernels are always batched and can never run into
# the tensor.numel() == 0 case.
if M_full == 0:
assert num_chunks == 0
workspace13 = None
@@ -1313,19 +1474,18 @@ class FusedMoEModularKernel(torch.nn.Module):
assert shared_output is not None
return shared_output, output
def forward(
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
activation: MoEActivation = MoEActivation.SILU,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
shared_experts_input: torch.Tensor | None = None,
**kwargs
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@@ -1335,8 +1495,7 @@ class FusedMoEModularKernel(torch.nn.Module):
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_weights (torch.Tensor): The topk weights applied at the end of the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (MoEActivation): The activation function to apply after the first
MoE layer.
@@ -1355,23 +1514,6 @@ class FusedMoEModularKernel(torch.nn.Module):
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
from .fused_moe import fused_experts as fused_experts_kernel
result = fused_experts_kernel(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
quant_config=kwargs.get("quant_config", None),
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
return result
if self.inplace:
assert self.shared_experts is None
assert not disable_inplace()
@@ -1417,3 +1559,206 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
@final
class FusedMoEKernelMonolithicImpl:
def __init__(
self,
prepare_finalize: FusedMoEPrepareAndFinalizeMonolithic,
fused_experts: FusedMoEExpertsMonolithic,
):
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
"""
Same as FusedMoEKernelModularImpl.apply(), except it takes router_logits
instead of topk_ids and topk_weights. This is used for kernels
that have a fused router + experts (e.g. FLASHINFER_TRTLLM).
"""
# TODO(rob): add inplace support.
a1q, a1q_scale, router_logits = self.prepare_finalize.prepare(
hidden_states,
router_logits=router_logits,
quant_config=self.fused_experts.quant_config,
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
)
fused_out = self.fused_experts.apply(
hidden_states=a1q,
w1=w1,
w2=w2,
router_logits=router_logits,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
a1q_scale=a1q_scale,
# grouped topk + fused topk bias parameters
num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
topk_group=topk_group,
)
output = self.prepare_finalize.finalize(fused_out)
return output
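For contrast with the modular path, the routing work that a monolithic kernel folds in (turning router_logits into top-k ids and weights) can be sketched in plain PyTorch. The renormalized softmax top-k below is a common choice and an assumption here, not the exact routing of any specific backend.

```python
import torch


def route(router_logits: torch.Tensor, topk: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Softmax over experts, keep the top-k, renormalize the kept weights.
    probs = router_logits.softmax(dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = probs.topk(topk, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


weights, ids = route(torch.randn(4, 16), topk=2)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```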
@final
class FusedMoEKernel:
def __init__(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEExperts,
shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
inplace: bool = False,
):
super().__init__()
self.shared_experts = shared_experts # NOTE: check if we can remove
# Initialize the implementation (monolithic or modular).
self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl
if isinstance(
prepare_finalize, FusedMoEPrepareAndFinalizeModular
) and isinstance(fused_experts, FusedMoEExpertsModular):
self.impl = FusedMoEKernelModularImpl(
prepare_finalize,
fused_experts,
shared_experts,
moe_parallel_config,
inplace,
)
elif isinstance(
prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic
) and isinstance(fused_experts, FusedMoEExpertsMonolithic):
assert shared_experts is None
assert not inplace
self.impl = FusedMoEKernelMonolithicImpl(
prepare_finalize,
fused_experts,
)
else:
raise ValueError(
"prepare_finalize and fused_experts must both be either monolithic "
f"or non-monolithic but got {prepare_finalize.__class__.__name__} "
f"and {fused_experts.__class__.__name__}"
)
self._post_init_setup()
@property
def is_monolithic(self) -> bool:
return isinstance(self.impl, FusedMoEKernelMonolithicImpl)
@property
def prepare_finalize(self) -> FusedMoEPrepareAndFinalize:
return self.impl.prepare_finalize
@property
def fused_experts(self) -> FusedMoEExperts:
return self.impl.fused_experts
def _post_init_setup(self):
"""
Resolve any leftover setup dependencies between self.prepare_finalize
and self.fused_experts here.
"""
self.prepare_finalize.post_init_setup(self.impl.fused_experts)
assert (
self.prepare_finalize.activation_format
== self.fused_experts.activation_format()
)
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps.
"""
return self.fused_experts.supports_expert_map()
def output_is_reduced(self) -> bool:
"""
Indicates whether or not the output of the fused MoE kernel
is reduced across all ranks.
"""
return self.prepare_finalize.output_is_reduced()
def apply_monolithic(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
# grouped topk + fused topk bias parameters
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
assert isinstance(self.impl, FusedMoEKernelMonolithicImpl)
return self.impl.apply(
hidden_states=hidden_states,
w1=w1,
w2=w2,
router_logits=router_logits,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
topk_group=topk_group,
)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
shared_experts_input: torch.Tensor | None = None,
) -> torch.Tensor:
assert isinstance(self.impl, FusedMoEKernelModularImpl)
return self.impl.apply(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)

View File

@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
logger = init_logger(__name__)
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""
Prepare/Finalize using MoRI kernels.
"""

View File

@@ -18,13 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_fp8,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
make_fp8_moe_alpha_scales_for_fi,
prepare_fp8_moe_layer_for_fi,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -103,9 +99,13 @@ def _get_priority_backends(
def backend_to_kernel_cls(
backend: Fp8MoeBackend,
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
) -> type[mk.FusedMoEExperts]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
raise NotImplementedError
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501
TrtLlmFp8Experts,
)
return TrtLlmFp8Experts
elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
@@ -205,13 +205,11 @@ def select_fp8_moe_backend(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
allow_vllm_cutlass: bool = False,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts] | None]:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None
if config.is_lora_enabled:
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
@@ -252,7 +250,7 @@ def select_fp8_moe_backend(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
@@ -287,16 +285,6 @@ def select_fp8_moe_backend(
"vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
)
# Handle FLASHINFER_TRTLLM specially (no kernel class).
if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(requested_backend))
return requested_backend, None
raise ValueError(_make_log_unsupported(requested_backend, reason))
return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format
)
@@ -311,51 +299,32 @@ def select_fp8_moe_backend(
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend()
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend))
return backend, None
else:
raise ValueError(_make_log_unsupported(backend, reason))
elif fi_backend == FlashinferMoeBackend.CUTLASS:
if fi_backend == FlashinferMoeBackend.CUTLASS:
backend = Fp8MoeBackend.FLASHINFER_CUTLASS
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
else:
assert fi_backend == FlashinferMoeBackend.CUTEDSL
raise ValueError("FlashInfer MaskedGEMM not supported for FP8")
raise ValueError(
f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
)
k_cls = backend_to_kernel_cls(backend)
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
else:
# If the user is not explicit about the backend, try both.
for backend in [
Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS,
]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
k_cls = None
supported, reason = is_supported_config_trtllm_fp8(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
@@ -408,23 +377,14 @@ def select_fp8_moe_backend(
# Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
k_cls = None
supported, reason = is_supported_config_trtllm_fp8(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
@@ -510,7 +470,7 @@ def make_fp8_moe_quant_config(
block_shape: list[int] | None = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
) -> FusedMoEQuantConfig | None:
) -> FusedMoEQuantConfig:
"""
Create FusedMoEQuantConfig for the specified FP8 Backend.
The FusedMoEQuantConfig holds the scales that are used
@@ -523,9 +483,6 @@ def make_fp8_moe_quant_config(
In a future PR, this function should become
a method of the modular kernel itself.
"""
# TRTLLM does not use Modular Kernel abstraction yet.
if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
# MARLIN is mixed precision W8A16 config.
if fp8_backend == Fp8MoeBackend.MARLIN:
@@ -539,12 +496,6 @@ def make_fp8_moe_quant_config(
# (alpha = w_scale * a_scale) and inverse a2 scale.
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
assert a1_scale is not None and a2_scale is not None
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
@@ -552,8 +503,8 @@ def make_fp8_moe_quant_config(
a2_scale=a2_scale,
a1_gscale=(1.0 / a1_scale),
a2_gscale=(1.0 / a2_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
g1_alphas=(w1_scale * a1_scale).squeeze(),
g2_alphas=(w2_scale * a2_scale).squeeze(),
)
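# Worked example (illustrative): with per-tensor scales such as
# w1_scale[e] = 0.02 and a1_scale = 0.5, each expert's alpha becomes
# 0.02 * 0.5 = 0.01, i.e. the factor that rescales the fp8 GEMM
# accumulator back to the unquantized activation range.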
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
@@ -570,17 +521,18 @@ def make_fp8_moe_quant_config(
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
experts_cls: type[mk.FusedMoEExperts],
fp8_backend: Fp8MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
) -> mk.FusedMoEKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
)
assert prepare_finalize is not None
@@ -603,9 +555,9 @@ def make_fp8_moe_kernel(
)
# NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in
# if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class.
kernel = mk.FusedMoEModularKernel(
kernel = mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=(

View File

@@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
@@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
def backend_to_kernel_cls(
backend: NvFp4MoeBackend,
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
) -> list[type[mk.FusedMoEExperts]]:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
raise NotImplementedError(
"FLASHINFER_TRTLLM doesn't support Modular Kernel Interface"
from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
TrtLlmNvFp4ExpertsModular,
TrtLlmNvFp4ExpertsMonolithic,
)
# NOTE: prefer Monolithic > Modular, so return Monolithic first.
return [
TrtLlmNvFp4ExpertsMonolithic,
TrtLlmNvFp4ExpertsModular,
]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
return FlashInferExperts
return [FlashInferExperts]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
FlashInferCuteDSLExperts,
)
return FlashInferCuteDSLExperts
return [FlashInferCuteDSLExperts]
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
return CutlassExpertsFp4
return [CutlassExpertsFp4]
elif backend == NvFp4MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
return MarlinExperts
return [MarlinExperts]
else:
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
@@ -125,7 +131,7 @@ def select_nvfp4_moe_backend(
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
"""
Select the primary NvFP4 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
@@ -143,10 +149,7 @@ def select_nvfp4_moe_backend(
# NOTE(rob): this is kind of a hack. We need to peek into
# the prepare-finalize selection to determine if we are using
# the batched or standard expert format.
use_batched = (
config.moe_parallel_config.use_deepep_ll_kernels
or config.moe_parallel_config.use_pplx_kernels
)
use_batched = config.moe_parallel_config.use_deepep_ll_kernels
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if use_batched
@@ -178,29 +181,21 @@ def select_nvfp4_moe_backend(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend))
return backend, k_cls
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend))
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user.
runner_backend = config.moe_backend
if runner_backend != "auto":
requested_backend = map_nvfp4_backend(runner_backend)
if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
supported, reason = is_supported_config_trtllm(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(requested_backend))
return requested_backend, None
raise ValueError(_make_log_unsupported(requested_backend, reason))
return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format
)
@@ -213,36 +208,14 @@ def select_nvfp4_moe_backend(
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend()
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = NvFp4MoeBackend.FLASHINFER_TRTLLM
supported, reason = is_supported_config_trtllm(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend))
return backend, None
else:
raise ValueError(_make_log_unsupported(backend, reason))
else:
backend = fi_2_vllm_backend_map[fi_backend]
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
else:
# If the user is not explicit about the backend, try each.
for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
k_cls = None
supported, reason = is_supported_config_trtllm(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls,
config,
@@ -250,13 +223,13 @@ def select_nvfp4_moe_backend(
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, None
else:
logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
)
raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
@@ -271,16 +244,7 @@ def select_nvfp4_moe_backend(
# Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
k_cls = None # type: ignore[assignment]
supported, reason = is_supported_config_trtllm(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls,
config,
@@ -289,11 +253,11 @@ def select_nvfp4_moe_backend(
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No NvFp4 MoE backend supports the deployment configuration."
@@ -401,12 +365,8 @@ def make_nvfp4_moe_quant_config(
w2_scale_2: torch.Tensor,
a13_scale: torch.Tensor,
a2_scale: torch.Tensor,
) -> FusedMoEQuantConfig | None:
UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM]
if backend in UNSUPPORTED:
return None
elif backend == NvFp4MoeBackend.MARLIN:
) -> FusedMoEQuantConfig:
if backend == NvFp4MoeBackend.MARLIN:
return nvfp4_w4a16_moe_quant_config(
g1_alphas=w13_scale_2,
g2_alphas=w2_scale_2,
@@ -423,22 +383,27 @@ def make_nvfp4_moe_quant_config(
a2_gscale=(1.0 / a2_scale),
w1_scale=w13_scale,
w2_scale=w2_scale,
# NOTE(rob): this is a hack until the MoE kernels
# create their own quant configs. TRTLLM kernel
# does not accept swizzled input quant scales.
is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM),
)
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
experts_cls: type[mk.FusedMoEExperts],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
) -> mk.FusedMoEKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
)
assert prepare_finalize is not None
@@ -461,9 +426,9 @@ def make_nvfp4_moe_kernel(
)
# NOTE(rob): we only want the mk to control the shared_expert
# if using all2all (for SBO). bnell is making this explict in
# if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class.
kernel = mk.FusedMoEModularKernel(
kernel = mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=(

View File

@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_bf16,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
MoEPrepareAndFinalizeNoDPEPModular,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31,
@@ -209,7 +209,7 @@ def make_unquantized_moe_kernel(
backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
) -> mk.FusedMoEModularKernel | None:
) -> mk.FusedMoEKernel | None:
if backend in UNSUPPORTED_BACKEND:
return None
@@ -218,8 +218,8 @@ def make_unquantized_moe_kernel(
FlashInferExperts,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoDPEPModular(),
FlashInferExperts(
moe_config=moe_config,
quant_config=quant_config,
@@ -232,8 +232,8 @@ def make_unquantized_moe_kernel(
AiterExperts,
)
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoDPEPModular(),
AiterExperts(
moe_config=moe_config,
quant_config=quant_config,
@@ -241,25 +241,6 @@ def make_unquantized_moe_kernel(
inplace=not moe_config.disable_inplace,
)
elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
elif backend == UnquantizedMoeBackend.XPU:
from vllm.model_executor.layers.fused_moe import XPUExperts
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
XPUExperts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
from vllm.model_executor.layers.fused_moe import fused_experts
kernel = fused_experts
return kernel

View File

@@ -1,373 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import pplx_kernels as pplx
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import (
_validate_scale_shape,
moe_kernel_quantize_input,
)
from vllm.utils.math_utils import cdiv, round_up
logger = init_logger(__name__)
def pplx_hidden_dim_scale_bytes(
max_num_tokens: int,
hidden_dim: int,
in_dtype: torch.dtype,
quant_dtype: torch.dtype | str | None,
per_act_token_quant: bool,
block_shape: list[int] | None,
):
# All pplx byte sizes must be 16-byte aligned.
align = 16
# For blocked per token: set to
# cdiv(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
if quant_dtype is not None:
assert isinstance(quant_dtype, torch.dtype)
assert quant_dtype.itemsize == 1
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
elem_size = torch.float32.itemsize
if per_act_token_quant:
# per-token (M x 1)
assert block_shape is None
hidden_scale_bytes = elem_size
elif block_shape is not None:
# per-group (M x K_tiles)
block_size = block_shape[1]
num_blocks = cdiv(hidden_dim, block_size)
hidden_scale_bytes = num_blocks * elem_size
else:
# per-tensor (1 x 1)
hidden_scale_bytes = elem_size
else:
hidden_dim_bytes = hidden_dim * in_dtype.itemsize
hidden_scale_bytes = 0
return (
round_up(hidden_dim_bytes, align),
round_up(hidden_scale_bytes, align),
)
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""PPLX-based prepare and finalize for expert parallelism."""
def __init__(
self,
a2a: pplx.AllToAll,
max_num_tokens: int,
num_local_experts: int,
num_dispatchers: int,
):
super().__init__()
assert max_num_tokens > 0
assert num_local_experts > 0
self.a2a = a2a
self.max_num_tokens = max_num_tokens
self.num_local_experts = num_local_experts
self.num_dispatchers_ = num_dispatchers
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def max_num_tokens_per_rank(self) -> int | None:
return self.max_num_tokens
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.uint32
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return True
def supports_async(self) -> bool:
return True
def prepare_async(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
assert topk_ids.size(0) == num_tokens
# expert_map should be None because with expert map, -1 id is used for
# non-local token; this causes error when casting ids to the
# topk_indices_dtype() int32
#
if expert_map is not None:
logger.warning_once(
"The PPLX backend does not support expert mapping. "
"The provided `expert_map` will be ignored."
)
expert_map = None # noqa: F841
# Is this always going to be a1.device?
device = a1.device
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype)
repeat_cols = 4
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
# TODO(bnell): always pass quant_config.a1_scale?
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
(None if quant_config.per_act_token_quant else quant_config.a1_scale),
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
_validate_scale_shape(
a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape
)
orig_a_scale_block_shape: int | None = None
if a1q_scale is not None:
scalar_scales = a1q_scale.numel() == 1
# pplx requires 2-d scales even for scalar scales
if a1q_scale.dim() <= 1:
assert scalar_scales
a1q_scale = a1q_scale.view(1, 1)
orig_a_scale_block_shape = a1q_scale.shape[-1]
if not quant_config.is_block_quantized:
# TODO (bnell): use group_broadcast instead?
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
assert a1q_scale is None or a1q_scale.ndim == 2, (
f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
)
expert_num_tokens = torch.empty(
self.num_local_experts,
dtype=torch.int32,
device=device,
)
expert_x = torch.empty(
(
self.num_local_experts,
self.max_num_tokens * self.num_dispatchers(),
hidden_dim,
),
dtype=a1q.dtype,
device=device,
)
expert_x_scale: torch.Tensor | None = None
if a1q.dtype.itemsize == 1:
if quant_config.is_per_act_token:
# (M x 1) -> (E x M x K)
final_dim = expert_x.size(2)
elif quant_config.is_per_tensor:
# (1 x 1) -> (E x 1 x 1)
final_dim = 1
else:
# (M x K_tiles) -> (E x M x K_tiles)
assert quant_config.block_shape is not None
num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1])
final_dim = num_blocks
expert_x_scale_shape = (
self.num_local_experts,
expert_x.size(1),
round_up(final_dim, 4), # round up for alignment
)
expert_x_scale = torch.empty(
expert_x_scale_shape,
dtype=torch.float32,
device=expert_x.device,
)
# This argument is optional, defaults to indices.size(0)
# There's not much point setting this unless it is != indices.size(0)
bound_m: torch.Tensor | None = None
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=True,
do_recv=False,
)
hook = lambda: self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
return (
hook,
lambda: self._receiver(
expert_num_tokens,
expert_x,
expert_x_scale,
orig_a_scale_block_shape,
),
)
def _receiver(
self,
expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor,
expert_x_scale: torch.Tensor | None,
orig_a_scale_block_shape: int | None,
) -> mk.PrepareResultType:
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
assert expert_x_scale.ndim == 3
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
)
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
defer_input_quant=defer_input_quant,
)
hook()
return receiver()
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
"Weight application and reduction happens in the combine kernel."
)
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
bound_m: torch.Tensor | None = None
# TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
# num_tokens = output.size(0) # M
# assert topk_ids.size(0) == num_tokens, (
# f"{topk_ids.size(0)} == {num_tokens}")
assert topk_ids.size() == topk_weights.size(), (
f"{topk_ids.size()} == {topk_weights.size()}"
)
assert output.size(0) <= self.max_num_tokens, (
f"{output.size(0)} <= {self.max_num_tokens}"
)
assert output.size(1) == fused_expert_output.size(-1)
# Set weights to 1 if we did them in dispatch. This is hacky.
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
self.a2a.combine(
out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=True,
do_recv=False,
)
return lambda: self.a2a.combine(
out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
receiver = self.finalize_async(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
)
receiver()

View File

@@ -1,209 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
is_sequence_parallel: bool = False,
num_dispatchers: int = 1,
) -> None:
super().__init__()
self.is_sequence_parallel = is_sequence_parallel
self._num_dispatchers = num_dispatchers
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return self._num_dispatchers
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quantization to the MoE kernel.
use_nvfp4 = quant_config.use_nvfp4_w4a4
if defer_input_quant:
a1q = a1
a1q_scale = None
else:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
# NOTE: swizzling pads the scales to multiple of 128
# which makes the scales tensor different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
)
# Skip gathering scales if we have static quantization
# (the scale is a scalar, replicated on all ranks) or
# if quantization is deferred.
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
scales = None if skip_gather_scales else [a1q_scale]
res = get_ep_group().dispatch(
a1q,
topk_weights,
topk_ids,
is_sequence_parallel=self.is_sequence_parallel,
extra_tensors=scales,
)
if skip_gather_scales:
a1q, topk_weights, topk_ids = res
else:
a1q, topk_weights, topk_ids, scales = res
assert scales is not None and len(scales) == 1
a1q_scale = scales[0]
if quant_config.quant_dtype == "nvfp4":
assert a1q_scale is not None
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
out = weight_and_reduce_impl.apply(
output=None,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
output.copy_(
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
)
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
"""MoE prepare and finalize without expert parallelism."""
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return 1
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts.
if defer_input_quant:
return a1, None, None, None, None
input_sf = (
quant_config.a1_gscale
if quant_config.use_nvfp4_w4a4
else quant_config.a1_scale
)
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
input_sf,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
)
return a1q, a1q_scale, None, None, None
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(
output=output,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)

View File

@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import (
MoEPrepareAndFinalizeNaiveDPEPModular,
MoEPrepareAndFinalizeNaiveDPEPMonolithic,
make_moe_prepare_and_finalize_naive_dp_ep,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import (
MoEPrepareAndFinalizeNoDPEPModular,
MoEPrepareAndFinalizeNoDPEPMonolithic,
make_moe_prepare_and_finalize_no_dp_ep,
)
__all__ = [
"MoEPrepareAndFinalizeNaiveDPEPMonolithic",
"MoEPrepareAndFinalizeNaiveDPEPModular",
"make_moe_prepare_and_finalize_naive_dp_ep",
"MoEPrepareAndFinalizeNoDPEPMonolithic",
"MoEPrepareAndFinalizeNoDPEPModular",
"make_moe_prepare_and_finalize_no_dp_ep",
]

View File

@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def _quantize_and_setup_dispatch(
a1: torch.Tensor,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
# Defer input quantization to the MoE kernel.
if defer_input_quant:
a1q = a1
a1q_scale = None
else:
input_sf = (
quant_config.a1_gscale
if quant_config.use_nvfp4_w4a4
else quant_config.a1_scale
)
# NOTE: swizzling pads the scales to a multiple of 128,
# which makes the scales tensor a different shape than
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
input_sf,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
is_fp4_scale_swizzled=False,
)
# Skip gathering scales if we have static quantization
# (the scale is a scalar, replicated on all ranks) or
# if quantization is deferred.
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
scales = None if skip_gather_scales else [a1q_scale]
return a1q, scales
def _unwrap_scale_and_prepare_for_moe(
scales: list[torch.Tensor] | None,
quant_config: FusedMoEQuantConfig,
) -> torch.Tensor:
assert scales is not None and len(scales) == 1
a1q_scale = scales[0]
# Apply swizzling after a2a if the MoE kernel needs it.
if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
assert a1q_scale is not None
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q_scale
class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
"""
Naive Prepare/Finalize for the DP/EP case for Modular Kernels.
Uses Torch AR/RS or AR for dispatch/combine operations, applied
to the topk weights and ids.
"""
def __init__(
self,
is_sequence_parallel: bool = False,
num_dispatchers: int = 1,
) -> None:
super().__init__()
self.is_sequence_parallel = is_sequence_parallel
self._num_dispatchers = num_dispatchers
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return self._num_dispatchers
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
"""Quantize and Dispatch Topk Weights and Topk Ids."""
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
res = get_ep_group().dispatch(
a1q,
topk_weights,
topk_ids,
is_sequence_parallel=self.is_sequence_parallel,
extra_tensors=scales,
)
if scales is None:
a1q, topk_weights, topk_ids = res
a1q_scale = None
else:
a1q, topk_weights, topk_ids, scales = res
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
out = weight_and_reduce_impl.apply(
output=None,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
output.copy_(
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
)
class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
"""
Naive Prepare/Finalize for the DP/EP case for Monolithic Kernels.
Uses Torch AR/RS or AR for dispatch/combine operations, applied
to the router logits (the MoE kernel runs the router internally).
"""
def __init__(
self,
is_sequence_parallel: bool = False,
num_dispatchers: int = 1,
) -> None:
super().__init__()
self.is_sequence_parallel = is_sequence_parallel
self._num_dispatchers = num_dispatchers
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return self._num_dispatchers
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
router_logits: torch.Tensor,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareMonolithicResultType:
"""Quantize and Dispatch Router Logits."""
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
res = get_ep_group().dispatch_router_logits(
a1q,
router_logits,
is_sequence_parallel=self.is_sequence_parallel,
extra_tensors=scales,
)
if scales is None:
a1q, router_logits = res
a1q_scale = None
else:
a1q, router_logits, scales = res
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
return a1q, a1q_scale, router_logits
def finalize(
self,
fused_expert_output: torch.Tensor,
) -> torch.Tensor:
out = get_ep_group().combine(
fused_expert_output, is_sequence_parallel=self.is_sequence_parallel
)
return out
def make_moe_prepare_and_finalize_naive_dp_ep(
use_monolithic: bool,
is_sequence_parallel: bool = False,
num_dispatchers: int = 1,
) -> MoEPrepareAndFinalizeNaiveDPEPModular | MoEPrepareAndFinalizeNaiveDPEPMonolithic:
return (
MoEPrepareAndFinalizeNaiveDPEPMonolithic(
is_sequence_parallel=is_sequence_parallel,
num_dispatchers=num_dispatchers,
)
if use_monolithic
else MoEPrepareAndFinalizeNaiveDPEPModular(
is_sequence_parallel=is_sequence_parallel,
num_dispatchers=num_dispatchers,
)
)
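# Usage sketch (illustrative only): `experts_cls` stands in for whichever
# expert implementation the caller resolved; the monolithic/modular switch
# mirrors the call sites in the fp8/nvfp4 kernel builders.
#
#   pf = make_moe_prepare_and_finalize_naive_dp_ep(
#       use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
#       is_sequence_parallel=False,
#       num_dispatchers=1,
#   )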

View File

@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
def _quantize_input(
a1: torch.Tensor,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts.
if defer_input_quant:
return a1, None
input_sf = (
quant_config.a1_gscale if quant_config.use_nvfp4_w4a4 else quant_config.a1_scale
)
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
input_sf,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
)
return a1q, a1q_scale
class MoEPrepareAndFinalizeNoDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return 1
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
return a1q, a1q_scale, None, None, None
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(
output=output,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
class MoEPrepareAndFinalizeNoDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
def num_dispatchers(self) -> int:
return 1
def output_is_reduced(self) -> bool:
return False
def prepare(
self,
a1: torch.Tensor,
router_logits: torch.Tensor,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareMonolithicResultType:
a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
return a1q, a1q_scale, router_logits
def finalize(
self,
fused_expert_output: torch.Tensor,
) -> torch.Tensor:
return fused_expert_output
def make_moe_prepare_and_finalize_no_dp_ep(
use_monolithic: bool,
) -> MoEPrepareAndFinalizeNoDPEPModular | MoEPrepareAndFinalizeNoDPEPMonolithic:
return (
MoEPrepareAndFinalizeNoDPEPMonolithic()
if use_monolithic
else MoEPrepareAndFinalizeNoDPEPModular()
)

View File

@@ -292,7 +292,7 @@ def rocm_aiter_fused_experts(
)
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
class AiterExperts(mk.FusedMoEExpertsModular):
@property
def expects_unquantized_inputs(self) -> bool:
return True

View File

@@ -20,6 +20,7 @@ import torch
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
@@ -132,7 +133,7 @@ class RoutedExpertsCapturer:
self._device_buffer = torch.zeros(
(max_num_batched_tokens, num_layers, num_experts_per_tok),
dtype=torch.int32,
device="cuda",
device=current_platform.device_type,
)
self.dp_rank = vllm_config.parallel_config.data_parallel_rank

View File

@@ -64,7 +64,7 @@ if current_platform.is_cuda_alike():
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# `FusedMoEPrepareAndFinalizeModular` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
@@ -175,6 +175,7 @@ class BaseRouter(FusedMoERouter):
topk_ids = topk_ids.to(dtype=indices_type)
assert topk_ids.dtype == indices_type or indices_type is None
topk_ids = topk_ids.to(torch.int32)
return topk_ids
@abstractmethod

View File

@@ -31,7 +31,7 @@ def vllm_topk_softmax(
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
e_score_correction_bias
)
return topk_weights, topk_indices
@@ -85,13 +85,14 @@ def fused_topk_bias(
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
if scoring_func == "softmax":
topk_weights, topk_ids = vllm_topk_softmax(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
gating_output_float,
renormalize,
e_score_correction_bias,
)
@@ -186,7 +187,7 @@ class FusedTopKBiasRouter(BaseRouter):
indices_type=indices_type,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
# if self.routed_scaling_factor != 1.0:
# topk_weights *= self.routed_scaling_factor
return topk_weights, topk_ids

View File

@@ -26,8 +26,9 @@ def vllm_topk_softmax(
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices
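# Worked example (illustrative): for one token with raw top-2 weights
# [0.5, 0.3], renormalization divides by their sum 0.8, giving [0.625, 0.375].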
@@ -90,13 +91,14 @@ def fused_topk(
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
gating_output_float = gating_output.float()
if scoring_func == "softmax":
topk_func = dispatch_topk_softmax_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize
topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
)
return topk_weights, topk_ids, token_expert_indices
@@ -105,7 +107,7 @@ def fused_topk(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize
topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
)
return topk_weights, topk_ids, token_expert_indices

View File

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.platforms import current_platform
@PluggableLayer.register("gate_linear")
class GateLinear(ReplicatedLinear):
"""MoE gate linear layer with three-tier GEMM dispatch:
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
3. F.linear via ReplicatedLinear (ultimate fallback)
The ``out_dtype`` attribute is mutable and can be set after init
(e.g. when the required dtype depends on the expert quantization
method which is only known later).
"""
# Dimensions supported by the DSV3 specialized kernel
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
out_dtype: torch.dtype | None = None,
params_dtype: torch.dtype | None = None,
force_fp32_compute: bool = False,
prefix: str = "",
):
is_hopper_or_blackwell = current_platform.is_device_capability(
(9, 0)
) or current_platform.is_device_capability_family(100)
can_use_specialized_kernels = False
# If fp32 compute is required and no specialized kernel is available,
# store weights in fp32 so Tier 3 computes in fp32 natively.
if force_fp32_compute and not can_use_specialized_kernels:
params_dtype = torch.float32
super().__init__(
input_size,
output_size,
bias=bias,
params_dtype=params_dtype,
quant_config=None,
prefix=prefix,
)
self.out_dtype = out_dtype
# DSV3 specialized kernel eligibility (SM90+, exact dims)
self.allow_specialized_router_gemm = can_use_specialized_kernels
self.allow_dsv3_router_gemm = (
self.allow_specialized_router_gemm
and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS
and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
)
# cuBLAS bf16→fp32 eligibility
self.allow_cublas_router_gemm = (
self.allow_specialized_router_gemm
and self.weight.dtype == torch.bfloat16
and self.out_dtype == torch.float32
)
def set_out_dtype(self, out_dtype: torch.dtype) -> None:
"""Set output dtype for the router logits after init.
Useful when the required dtype depends on the expert quantization
method which is only known after the gate is constructed.
"""
if self.out_dtype is not None:
raise ValueError("out_dtype has already been set")
self.out_dtype = out_dtype
if (
not self.allow_cublas_router_gemm
and self.allow_specialized_router_gemm
and out_dtype == torch.float32
):
self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16
def forward(
self, x: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
import vllm._custom_ops as ops
# Tier 1: DSV3 specialized kernel
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
output = ops.dsv3_router_gemm(
hidden_states=x,
router_weight=self.weight,
output_dtype=self.out_dtype,
)
return output, None
# Tier 2: cuBLAS bf16→fp32
if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
output = ops.router_gemm_bf16_fp32(x, self.weight)
return output, None
# Tier 3: F.linear (ReplicatedLinear)
if self.out_dtype is not None and x.dtype != self.weight.dtype:
x = x.to(self.weight.dtype)
output, output_bias = super().forward(x)
if self.out_dtype is not None and output.dtype != self.out_dtype:
output = output.to(self.out_dtype)
return output, output_bias
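# Usage sketch (illustrative only): the sizes below come from the
# DSV3_SUPPORTED_* constants above; `hidden_states` is assumed to be a
# bf16 activation tensor from the surrounding model.
#
#   gate = GateLinear(input_size=7168, output_size=256, bias=False,
#                     force_fp32_compute=True)
#   gate.set_out_dtype(torch.float32)       # once the quant method is known
#   router_logits, _ = gate(hidden_states)  # tier 1/2/3 picked per call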

View File

@@ -92,77 +92,9 @@ def grouped_topk(
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if (
envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
and current_platform.is_cuda()
and num_expert_group <= 32
and topk <= 32
and e_score_correction_bias is not None
):
return fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
)
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.size(0)
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
)
else:
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=use_sorted
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
from ixformer.inference.functions import moe_grouped_topk as grouped_topk
topk_weights, topk_ids = grouped_topk(gating_output, topk, num_expert_group, topk_group, scoring_func, e_score_correction_bias, renormalize=renormalize)
return topk_weights, topk_ids
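# Tiny sketch of the selection-vs-weighting behaviour that the removed
# reference implementation had and that the ixformer moe_grouped_topk kernel is
# assumed to preserve: expert ids come from bias-corrected scores while the
# routing weights are gathered from the raw scores. The helper name is made up.
def _biased_topk_reference(scores: torch.Tensor, bias: torch.Tensor, k: int):
    topk_ids = torch.topk(scores + bias, k=k, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)  # weights from unbiased scores
    return topk_weights, topk_ids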
# --8<-- [start:grouped_topk]
@@ -246,7 +178,6 @@ class GroupedTopk(CustomOp):
hidden_states, gating_output, e_score_correction_bias
)
from ixformer.inference.functions import moe_grouped_topk as grouped_topk
class GroupedTopKRouter(BaseRouter):
"""Router using grouped top-k routing (e.g., DeepSeekV2/V3)."""
@@ -316,8 +247,8 @@ class GroupedTopKRouter(BaseRouter):
topk=self.top_k,
renormalize=self.renormalize,
)
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
# if self.routed_scaling_factor != 1.0:
# topk_weights *= self.routed_scaling_factor
else:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
@@ -340,14 +271,14 @@ class GroupedTopKRouter(BaseRouter):
grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl(
# hidden_states=hidden_states,
hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
# routed_scaling_factor=self.routed_scaling_factor,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
)

View File

@@ -44,7 +44,7 @@ def create_fused_moe_router(
# grouped topk + fused topk bias parameters
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
# custom routing paramaters
# custom routing parameters
custom_routing_function: Callable | None = None,
# eplb parameters
enable_eplb: bool = False,

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import (
HAS_OPAQUE_TYPE,
ModuleName,
aux_stream,
current_stream,
direct_register_custom_op,
@@ -56,13 +59,27 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
return forward_context.no_compile_layers[layer_name]
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
# on older versions it remains a plain str.
if TYPE_CHECKING:
from typing import TypeAlias
_layer_name_type: TypeAlias = str | ModuleName
else:
_layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str
def _resolve_layer_name(layer_name: str | ModuleName) -> str:
return layer_name.value if isinstance(layer_name, ModuleName) else layer_name
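# Illustrative round-trip for the helper above (the layer name is made up): on
# torch builds that expose opaque types the name is wrapped in ModuleName
# before crossing the custom-op boundary and unwrapped again on the other side.
def _layer_name_roundtrip_example() -> None:
    name = "model.layers.0.mlp.experts"  # hypothetical layer name
    wrapped = ModuleName(name) if HAS_OPAQUE_TYPE else name
    assert _resolve_layer_name(wrapped) == name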
def _moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: str,
layer_name: _layer_name_type,
) -> torch.Tensor:
layer = get_layer_from_name(layer_name)
layer = get_layer_from_name(_resolve_layer_name(layer_name))
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
@@ -74,7 +91,7 @@ def _moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: str,
layer_name: _layer_name_type,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -83,9 +100,9 @@ def _moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: str,
layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
layer = get_layer_from_name(layer_name)
layer = get_layer_from_name(_resolve_layer_name(layer_name))
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
@@ -97,7 +114,7 @@ def _moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
layer_name: str,
layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
@@ -105,12 +122,10 @@ def _moe_forward_shared_fake(
# hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out = torch.empty_like(hidden_states)
if shared_experts_input is not None:
shared_out = torch.empty_like(shared_experts_input)
else:
shared_out = torch.empty_like(hidden_states)
return shared_out, fused_out
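# Shape example for the fake op above (made-up sizes): with 8 tokens, routed
# hidden size 512 and a separate shared-experts input of size 2048 (the latent
# MoE case), fused_out is [8, 512] and shared_out is [8, 2048]; without a
# separate shared input both fall back to [8, 512].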
@@ -165,6 +180,7 @@ class DefaultMoERunner(MoERunner):
quant_method: FusedMoEMethodBase,
reduce_results: bool,
enable_dbo: bool,
fused_shared_output: bool = False,
):
super().__init__()
self.moe_config = moe_config
@@ -175,6 +191,9 @@ class DefaultMoERunner(MoERunner):
self.quant_method = quant_method
self.reduce_results = reduce_results
self.enable_dbo = enable_dbo
self.fused_shared_output = fused_shared_output
if self.fused_shared_output:
assert self.shared_experts is not None, "Shared experts must be provided when fused_shared_output is True."
# Allow disabling of the separate shared experts stream for
# debug purposes.
@@ -195,19 +214,19 @@ class DefaultMoERunner(MoERunner):
# Needed for string -> FusedMoE layer lookup in custom ops.
self.layer_name = layer.layer_name
if current_platform.is_tpu() or current_platform.is_cpu():
# if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
if self.shared_experts is None:
self.moe_forward = _moe_forward
else:
self.moe_forward = _moe_forward_shared
if self.shared_experts is None:
self.moe_forward = _moe_forward
else:
if self.shared_experts is None:
self.moe_forward = torch.ops.vllm.moe_forward
else:
self.moe_forward = torch.ops.vllm.moe_forward_shared
self.moe_forward = _moe_forward_shared
# else:
# if self.shared_experts is None:
# self.moe_forward = torch.ops.vllm.moe_forward
# else:
# self.moe_forward = torch.ops.vllm.moe_forward_shared
# Chunked all2all staging tensor
self.batched_hidden_states: torch.Tensor | None = None
@@ -216,8 +235,7 @@ class DefaultMoERunner(MoERunner):
@property
def use_dp_chunking(self) -> bool:
return (
self.moe_config.moe_parallel_config.use_pplx_kernels
or self.moe_config.moe_parallel_config.use_deepep_ll_kernels
self.moe_config.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.moe_parallel_config.use_mori_kernels
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@@ -306,8 +324,8 @@ class DefaultMoERunner(MoERunner):
"""
assert self.quant_method is not None
return (
self.quant_method.moe_mk is not None
and self.quant_method.moe_mk.output_is_reduced()
self.quant_method.moe_kernel is not None
and self.quant_method.moe_kernel.output_is_reduced()
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
@@ -362,13 +380,15 @@ class DefaultMoERunner(MoERunner):
if isinstance(states, tuple):
return tuple(
[func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
[None if s is None else func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
)
else:
assert len(trunc_sizes) == 1
return func(states, trunc_sizes[0])
def _encode_layer_name(self) -> str:
def _encode_layer_name(self) -> str | ModuleName:
if HAS_OPAQUE_TYPE:
return ModuleName(self.layer_name)
# Can be unavailable or None in unittests
if (
is_forward_context_available()
@@ -624,53 +644,27 @@ class DefaultMoERunner(MoERunner):
)
with sp_ctx:
extra_tensors = None
if do_naive_dispatch_combine:
post_quant_allgather = (
self.quant_method is not None
and self.moe_config.dp_size > 1
and self.moe_config.use_ep
and getattr(self.quant_method, "do_post_quant_allgather", False)
)
if post_quant_allgather:
hidden_states_to_dispatch, extra_tensors = (
self.quant_method.prepare_dp_allgather_tensor(
layer, hidden_states, router_logits
)
)
else:
hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch_router_logits(
hidden_states_to_dispatch,
router_logits,
self.moe_config.is_sequence_parallel,
extra_tensors=extra_tensors,
)
if extra_tensors is not None:
(
orig_hidden_states,
router_logits,
extra_tensors_combined,
) = dispatch_res
hidden_states_combined = (
orig_hidden_states,
extra_tensors_combined[0],
)
else:
hidden_states_combined, router_logits = dispatch_res
orig_hidden_states = hidden_states_combined
else:
orig_hidden_states = hidden_states
# Run the shared experts before the expert matrix multiply,
# because the matrix multiply may modify hidden_states in place.
if has_separate_shared_experts and not use_shared_experts_stream:
if has_separate_shared_experts: # and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_input = (
shared_input if shared_input is not None else hidden_states
)
shared_output = self.shared_experts(shared_input)
else:
assert not self.fused_shared_output, "fused_shared_output is only supported when has_separate_shared_experts is True."
shared_output = None
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
hidden_states,
router_logits,
self.moe_config.is_sequence_parallel,
)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
@@ -685,42 +679,33 @@ class DefaultMoERunner(MoERunner):
dim=0,
)
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
# Figure out nicer way to do this.
if do_naive_dispatch_combine:
x = hidden_states_combined
x_orig = orig_hidden_states
else:
x = hidden_states
x_orig = hidden_states
# Matrix multiply.
if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=layer,
x=x,
x=hidden_states,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=x_orig,
hidden_states=hidden_states,
router_logits=router_logits,
)
final_hidden_states = self.quant_method.apply(
layer=layer,
x=x, # The type signature of this is wrong due to the hack.
x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=shared_input,
router_logits=router_logits,
top_k=topk_ids.shape[-1],
# Pass the shared experts output via shared_experts_input so it can be
# fused into the routed-expert output inside quant_method.apply().
shared_experts_input=shared_output if self.fused_shared_output else None,
)
if has_separate_shared_experts:
assert self.shared_experts is not None
if use_shared_experts_stream:
assert not use_shared_experts_stream, "Running shared experts in parallel with the main MoE execution is currently not supported!"
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
@@ -733,7 +718,7 @@ class DefaultMoERunner(MoERunner):
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
shared_output,
None if self.fused_shared_output else shared_output,
final_hidden_states,
)
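# With fused_shared_output, quant_method.apply() has already folded the shared
# experts result into final_hidden_states, so the shared slot of the returned
# tuple is set to None here to keep the caller from adding it a second time.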

View File

@@ -10,14 +10,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
"""
Useful in the case when some FusedMoEPermuteExpertsUnpermute
Useful in the case when some FusedMoEExpertsModular
implementation does not perform weight application and reduction
but cannot address the needs of all the compatible PrepareAndFinalize
implementations.
For example, BatchedTritonExperts is compatible with both
PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize
does the weight-application + reduction as part of the pplx combine kernel.
But the BatchedPrepareAndFinalize needs an implementation. To facilitate
For example, BatchedTritonExperts is compatible with both batched
PrepareAndFinalize implementations like DeepEPLLPrepareAndFinalize and
BatchedPrepareAndFinalize. Some PrepareAndFinalize implementations do
the weight-application + reduction as part of the combine kernel, while
BatchedPrepareAndFinalize needs an explicit implementation. To facilitate
this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
so the PrepareAndFinalize implementations could choose how to
weight + reduce.
@@ -61,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
if output is None:
return fused_expert_output
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
# MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
# tensor.
assert output.size() == fused_expert_output.size(), (
"output shape is expected to match the fused_expert_output shape. "

View File

@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts):
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEExpertsModular],
]:
return (CutlassExpertsFp8, TritonExperts)
@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
# Small batch fallback for sm100.
if self.is_sm100 and hidden_states.shape[0] <= 8:
return self.fallback_experts

View File

@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts):
@staticmethod
def get_clses() -> tuple[
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEPermuteExpertsUnpermute],
type[mk.FusedMoEExpertsModular],
type[mk.FusedMoEExpertsModular],
]:
return (DeepGemmExperts, TritonExperts)
@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
return self.experts
else:

View File

@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
"""TensorRT-LLM-based fused MoE expert implementation."""
def __init__(

View File

@@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoEExpertsModular,
FusedMoEPrepareAndFinalizeModular,
)
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
UnquantizedMoeBackend,
@@ -42,9 +42,9 @@ from vllm.platforms.interface import CpuArchEnum
if current_platform.is_cuda_alike() or current_platform.is_xpu():
from .fused_batched_moe import BatchedTritonExperts
from .fused_moe import TritonExperts
else:
TritonExperts = None # type: ignore
fused_experts = None
logger = init_logger(__name__)
@@ -70,7 +70,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.rocm_aiter_moe_enabled = (
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
)
self.kernel: mk.FusedMoEModularKernel | None = None
self.kernel: mk.FusedMoEKernel | None = None
self._is_monolithic = (
current_platform.is_cpu()
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
@@ -107,7 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
) -> FusedMoEPrepareAndFinalizeModular | None:
if self.unquantized_backend == UnquantizedMoeBackend.AITER:
return None
else:
@@ -115,9 +115,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
) -> FusedMoEExpertsModular:
assert self.moe_quant_config is not None
if (
prepare_finalize.activation_format
@@ -296,16 +296,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
# shared_experts_input carries the shared experts output so that it can be
# fused into the routed-expert result below.
shared_experts_input: torch.Tensor | None,
**kwargs,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
result = self.forward(
layer=layer,
x=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
# not used
shared_experts_input=shared_experts_input,
)
) * layer.routed_scaling_factor
if shared_experts_input is not None:
result += shared_experts_input
return result
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
if self.moe.has_bias:
@@ -333,10 +337,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
shared_experts_input=shared_experts_input,
expert_map=layer.expert_map
)
def forward_monolithic_cuda(

View File

@@ -23,7 +23,7 @@ if current_platform.is_xpu():
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
class XPUExperts(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,

View File

@@ -82,11 +82,12 @@ def fused_add_rms_norm(
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual
ops.fused_add_rms_norm(
x, residual = ops.fused_add_rms_norm(
x,
residual,
weight,
variance_epsilon,
residual_alpha,
)
return x, residual
@@ -125,7 +126,7 @@ def dispatch_rocm_rmsnorm_func(
return fused_add_rms_norm
return rms_norm
def rms_norm_qk(
input_q: torch.Tensor,
input_k: torch.Tensor,
@@ -140,11 +141,7 @@ def rms_norm_qk(
output_q, output_k, input_q, input_k, weight_q, weight_k, epsilon
)
return output_q, output_k
def dispatch_cuda_rmsnorm_qk_func() -> callable:
return rms_norm_qk
@CustomOp.register("rms_norm_qk")
class RMSNormQK(CustomOp):
@@ -226,8 +223,7 @@ class RMSNormQK(CustomOp):
f"[RMSNormQK] Expected input_q and input_k to have same dtype, "
f"but got {input_q.dtype} vs {input_k.dtype}"
)
norm_func = dispatch_cuda_rmsnorm_qk_func()
return norm_func(
return rms_norm_qk(
input_q,
input_k,
weight_q,
@@ -264,7 +260,7 @@ class RMSNormQK(CustomOp):
f"eps={self.variance_epsilon}, "
)
# --8<-- [start:rms_norm]
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.
@@ -375,7 +371,7 @@ class RMSNorm(CustomOp):
# otherwise Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching)
x = x + residual
residual = x.to(orig_dtype).contiguous()
residual = x.to(orig_dtype)
if x.shape[-1] != hidden_size:
raise ValueError(
@@ -425,6 +421,7 @@ class RMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
residual_alpha: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
@@ -499,7 +496,7 @@ class RMSNorm(CustomOp):
add_residual = residual is not None
if add_residual:
return fused_add_rms_norm(
x, residual, self.weight.data, self.variance_epsilon
x, residual, self.weight.data, self.variance_epsilon, residual_alpha
)
else:
return rms_norm(x, self.weight.data, self.variance_epsilon)
@@ -649,6 +646,7 @@ class RMSNormGated(CustomOp):
norm_before_gate: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
activation: str = "swish",
):
"""Initialize RMSNormGated.
@@ -663,10 +661,12 @@ class RMSNormGated(CustomOp):
If False and z is provided: out = norm(x * silu(z))
device: Device to create parameters on
dtype: Data type for parameters
activation: Activation function name for gating
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.activation = activation
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
@@ -693,6 +693,11 @@ class RMSNormGated(CustomOp):
- norm_before_gate=True: out = norm(x) * silu(z)
- norm_before_gate=False: out = norm(x * silu(z))
"""
orig_dtype = x.dtype
x = x.float()
weight = self.weight.float()
z = z.float() if z is not None else None
# Apply gating before normalization if needed
if z is not None and not self.norm_before_gate:
x = x * F.silu(z)
@@ -702,7 +707,7 @@ class RMSNormGated(CustomOp):
# Standard RMS norm across the last dimension
variance = x.pow(2).mean(dim=-1, keepdim=True)
x_normed = x * torch.rsqrt(variance + self.eps)
out = x_normed * self.weight
out = x_normed * weight
else:
# Group RMS norm
from einops import rearrange
@@ -710,13 +715,13 @@ class RMSNormGated(CustomOp):
x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
variance = x_group.pow(2).mean(dim=-1, keepdim=True)
x_normed = x_group * torch.rsqrt(variance + self.eps)
out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight
out = rearrange(x_normed, "... g d -> ... (g d)") * weight
# Apply gating after normalization if needed
if z is not None and self.norm_before_gate:
out = out * F.silu(z)
return out
return out.to(orig_dtype)
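# Equivalent formulation of the two gating orders handled above, writing
# rms(t) for t * rsqrt(mean(t**2, dim=-1) + eps):
#   norm_before_gate=True:  out = rms(x) * weight * silu(z)
#   norm_before_gate=False: out = rms(x * silu(z)) * weight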
def forward_cuda(
self, x: torch.Tensor, z: torch.Tensor | None = None
@@ -731,6 +736,7 @@ class RMSNormGated(CustomOp):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
activation=self.activation,
)

View File

@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import ast, re
from abc import abstractmethod
from typing import Any
import torch
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -16,6 +16,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.batch_invariant import (
@@ -28,7 +29,9 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.utils import (
dispatch_unquantized_gemm,
is_layer_moe_router_gate,
parse_opt_exclude_layers,
weight_quant_l1,
weight_quant_l2,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
@@ -41,12 +44,11 @@ from vllm.model_executor.parameter import (
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
import vllm.envs as envs
from compressed_tensors.quantization import QuantizationStrategy
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"UnquantizedLinearMethod",
"CompressedTensorsLinearMethod",
"CompressedTensorsLinearTransformMethod",
"AWQMarlinLinearMethod",
@@ -66,6 +68,14 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"PetitNvFp4LinearMethod",
]
LINEAR_OPT_SUPPORTED = [
"ColumnParallelLinear",
"ReplicatedLinear",
"RowParallelLinear",
"QKVParallelLinear",
"MergedColumnParallelLinear",
]
def adjust_marlin_shard(
param: Parameter,
@@ -135,44 +145,6 @@ def adjust_scalar_to_fused_array(
return param_data[shard_id], loaded_weight
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(
bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
"""
Separate the BitsAndBytes 4-bit shard.
For example, given bnb weight attributes as below:
{
'bnb_shard_offsets': array([0, 4, 8, 16]),
'bnb_quant_state': {0: ..., 1: ..., 2: ...},
}
The function will return:
{
'bnb_shard_offsets': array([0, 4]),
'bnb_quant_state': {0: ...},
}
and
{
'bnb_shard_offsets': array([0, 4, 12]),
'bnb_quant_state': {0: ..., 1: ...},
}
"""
shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
offset_l = shard_offsets[:2]
offset_r = shard_offsets[1:] - shard_offsets[1]
quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
quant_state_r = {
i - 1: bnb_weight_attrs["bnb_quant_state"][i]
for i in range(1, len(shard_offsets) - 1)
}
left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
return left, right
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@@ -231,17 +203,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
# The weights are not quantized, and they are not sharded.
# The amount of memory allocated for the weights is
# sum(output_partition_sizes) * input_size_per_partition.
weight_loader = extra_weight_attrs.pop("weight_loader")
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
@@ -258,11 +224,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if (
vllm_is_batch_invariant()
and current_platform.is_cuda_alike()
and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
):
if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
@@ -305,15 +267,31 @@ class LinearBase(PluggableLayer):
self.quant_config = quant_config
self.prefix = prefix
self.allow_fp8_block_shape_mismatch = False
if quant_config is None:
self.opt_level = envs.VLLM_LINEAR_OPT_LEVEL
if parse_opt_exclude_layers(envs.VLLM_LINEAR_SPECIFIED_LAYERS, self.prefix) or \
(envs.VLLM_LINEAR_SPECIFIED_KEYS != "" and envs.VLLM_LINEAR_SPECIFIED_KEYS in self.prefix):
self.opt_level = envs.VLLM_LINEAR_SPECIFIED_OPT_LEVEL
self.opt_flag = quant_config is None and self.opt_level != 0 and \
self.__class__.__name__ in LINEAR_OPT_SUPPORTED
if parse_opt_exclude_layers(envs.VLLM_OPT_EXCLUDE_LAYERS, self.prefix):
self.opt_flag = False
logger.info(f"Excluding layer {self.prefix} from optimization")
if self.opt_flag:
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsLinearMethod
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import CompressedTensorsW8A8Int8
self.quant_method: QuantizeMethodBase | None = CompressedTensorsLinearMethod(None)
self.scheme = CompressedTensorsW8A8Int8(QuantizationStrategy.CHANNEL, False, True, is_w4a8_linear=(self.opt_level == 2))
elif quant_config is None:
self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
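# Example of how the selection above is assumed to resolve: with
# VLLM_LINEAR_OPT_LEVEL=2, no quantization config, and a layer class listed in
# LINEAR_OPT_SUPPORTED, the layer gets CompressedTensorsLinearMethod with the
# channel-wise W8A8 int8 scheme in its w4a8 mode; any prefix matched by
# VLLM_OPT_EXCLUDE_LAYERS falls back to the unquantized path.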
self.return_bias = return_bias
self.output_padding_size = 0
self.disable_tp = disable_tp
self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
self.output_padding_size = 0
def update_param_tp_status(self):
for param in self.parameters():
@@ -402,7 +380,7 @@ class ReplicatedLinear(LinearBase):
# If the weight on disk does not have a shape, give it one
# (such scales for AutoFp8).
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
@@ -419,7 +397,17 @@ class ReplicatedLinear(LinearBase):
f"Tried to load weights of size {loaded_weight.size()}"
f"to a parameter of size {param.size()}"
)
if self.opt_flag:
if self.opt_level == 1:
loaded_weight, scale = weight_quant_l1(loaded_weight)
elif self.opt_level == 2:
loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
param.data.copy_(loaded_weight)
if self.opt_flag:
params_dict = dict(self.named_parameters())
scale_param = params_dict["weight_scale"]
scale_param.data.copy_(scale)
def forward(
self,
@@ -609,7 +597,18 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
if self.opt_flag:
if self.opt_level == 1:
loaded_weight, scale = weight_quant_l1(loaded_weight)
elif self.opt_level == 2:
loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
param.load_column_parallel_weight(loaded_weight=loaded_weight)
if self.opt_flag:
params_dict = dict(self.named_parameters())
scale_param = params_dict["weight_scale"]
scale_param.load_column_parallel_weight(loaded_weight=scale)
def forward(
self,
@@ -733,16 +732,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_shard_id: tuple[int, ...] | int | None = None,
):
self.validate_shard_id(loaded_shard_id)
# FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
if isinstance(loaded_shard_id, tuple):
raise NotImplementedError(
"Shard id with multiple indices is not supported in weight_loader, "
"please use weight_loader_v2 instead."
)
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if isinstance(loaded_shard_id, tuple) and (
is_gguf_weight or is_gguf_weight_type
):
raise NotImplementedError(
"Shard id with multiple indices is not supported for GGUF."
)
if is_gguf_weight_type:
if loaded_shard_id is not None:
param.data[loaded_shard_id].copy_(loaded_weight)
@@ -770,7 +769,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
# Loaded weight is already fused on disk (mlp).
# (e.g., Phi-3's gate_up_proj).
if output_dim is None:
@@ -782,10 +781,25 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
output_sizes = (
self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
if loaded_shard_id is not None
else self.output_sizes
)
current_shard_offset = 0
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if (
use_bitsandbytes_4bit
and isinstance(loaded_shard_id, tuple)
and self.tp_size > 1
):
raise NotImplementedError(
"Shard id with multiple indices is not supported "
"for BNB quantization with TP yet."
)
shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
for i, output_size in enumerate(output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None)
@@ -850,9 +864,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
index = list(itertools.accumulate([0] + self.output_sizes))
orig_offsets = {
str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
}
orig_offsets["total"] = (self.output_size, 0)
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_offsets, str(loaded_shard_id)
)
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
start_idx = self.tp_rank * shard_size
if not is_sharded_weight:
@@ -921,12 +940,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight: torch.Tensor,
loaded_shard_id: tuple[int, ...] | int | None = None,
):
if self.opt_flag:
if self.opt_level == 1:
loaded_weight, scale = weight_quant_l1(loaded_weight)
elif self.opt_level == 2:
loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
self.validate_shard_id(loaded_shard_id)
dtype = loaded_weight.dtype
if envs.VLLM_W8A8_LINEAR_USE_W4A8 and not (param.shape[0] == 1 or param.shape[1] == 1) and dtype == torch.int8:
load_sizes = [self.output_sizes[i] // 2 for i in range(len(self.output_sizes))]
else:
load_sizes = self.output_sizes
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
@@ -953,19 +972,21 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes)
# shard_offset = sum(self.output_sizes[:loaded_shard_id])
# shard_size = self.output_sizes[loaded_shard_id]
shard_offset = sum(load_sizes[:loaded_shard_id])
shard_size = load_sizes[loaded_shard_id]
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
shard_offset //= self.tp_size
shard_size //= self.tp_size
scale_shard_offset = shard_offset
scale_shard_size = shard_size
if self.opt_flag and self.opt_level == 2:
shard_offset = shard_offset // 2
shard_size = shard_size // 2
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
param.load_merged_column_weight(
loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
@@ -973,6 +994,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size=shard_size,
tp_rank=self.tp_rank,
)
if self.opt_flag:
params_dict = dict(self.named_parameters())
scale_param = params_dict["weight_scale"]
scale_param.load_merged_column_weight(
loaded_weight=scale,
shard_id=loaded_shard_id,
shard_offset=scale_shard_offset,
shard_size=scale_shard_size,
tp_rank=self.tp_rank,
)
class QKVParallelLinear(ColumnParallelLinear):
@@ -1128,12 +1159,24 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset
)
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
if self.opt_level == 2:
loaded_weight_shard = loaded_weight.narrow(
0, shard_offset, shard_size
)
else:
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def quant(self, loaded_weight: torch.Tensor):
if self.opt_flag:
if self.opt_level == 1:
return weight_quant_l1(loaded_weight)
elif self.opt_level == 2:
return weight_quant_l2(loaded_weight, format="NN")
return loaded_weight, None
def weight_loader_v2(
self,
param: BasevLLMParameter,
@@ -1141,15 +1184,27 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: str | None = None,
):
self.validate_shard_id(loaded_shard_id)
params_dict = dict(self.named_parameters())
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
loaded_weight, scale = self.quant(loaded_weight)
param.load_qkv_weight(
loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
)
if self.opt_flag:
scale_param = params_dict["weight_scale"]
scale_param.load_qkv_weight(
loaded_weight=scale, shard_id=0, tp_rank=self.tp_rank
)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
loaded_weight, scale = self.quant(loaded_weight)
param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
if self.opt_flag:
scale_param = params_dict["weight_scale"]
scale_param.load_qkv_weight(loaded_weight=scale, tp_rank=self.tp_rank)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
@@ -1158,11 +1213,15 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
dtype = loaded_weight.dtype
# For the w4a8 GEMM path the weight shard offset/size must be halved; the scale does not need it.
if envs.VLLM_W8A8_LINEAR_USE_W4A8 and not (param.shape[0] == 1 or param.shape[1] == 1) and dtype == torch.int8:
shard_offset //= 2
shard_size //= 2
scale_shard_offset = shard_offset
scale_shard_size = shard_size
loaded_weight, scale = self.quant(loaded_weight)
if self.opt_flag and self.opt_level == 2:
shard_offset = shard_offset // 2
shard_size = shard_size // 2
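# Worked example of the halving above (two int4 weights are assumed to be
# packed into each int8 element along the output dim): a q shard at offset
# 1024 with size 512 becomes offset 512, size 256 in the packed weight, while
# scale_shard_offset / scale_shard_size keep the original 1024 / 512 for
# weight_scale.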
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
@@ -1179,6 +1238,15 @@ class QKVParallelLinear(ColumnParallelLinear):
tp_rank=self.tp_rank,
)
if self.opt_flag:
scale_param = params_dict["weight_scale"]
scale_param.load_qkv_weight(loaded_weight=scale,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=scale_shard_offset,
shard_size=scale_shard_size,
tp_rank=self.tp_rank)
def weight_loader(
self,
param: Parameter,
@@ -1525,7 +1593,17 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
if self.opt_flag:
if self.opt_level == 1:
loaded_weight, scale = weight_quant_l1(loaded_weight)
elif self.opt_level == 2:
loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
param.load_row_parallel_weight(loaded_weight=loaded_weight)
if self.opt_flag:
params_dict = dict(self.named_parameters())
scale_param = params_dict["weight_scale"]
scale_param.load_row_parallel_weight(loaded_weight=scale)
def forward(
self,

View File

@@ -61,7 +61,10 @@ class LogitsProcessor(CustomOp):
logits = hidden_states
else:
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
if hidden_states.shape[0] > 0:
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
else:
logits = torch.empty([0, lm_head.weight.shape[0]], device=hidden_states.device, dtype=hidden_states.dtype)
if logits is not None:
if self.soft_cap is not None:
logits = logits / self.soft_cap

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Callable
import torch
import torch.nn.functional as F
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
self.weight.weight_loader = self.weight_loader
self.variance_epsilon = eps
return
@staticmethod
def weight_loader(
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
shard_size = loaded_weight.shape[0] // tp_world
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
param.data.copy_(loaded_weight[shard])
return
def _forward(
self,
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
return q, k
def clear_linear_attention_cache_for_new_sequences(
kv_cache: torch.Tensor,
state_indices_tensor: torch.Tensor,
attn_metadata: LinearAttentionMetadata,
) -> None:
num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills <= 0:
return
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx + 1]
query_len = q_end - q_start
context_len = (
attn_metadata.seq_lens[num_decode_tokens + prefill_idx] - query_len
)
if context_len == 0:
block_to_clear = state_indices_tensor[num_decode_tokens + prefill_idx]
kv_cache[block_to_clear, ...] = 0
def linear_attention_decode(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
slope_rate: torch.Tensor,
state_indices_tensor: torch.Tensor,
q_start: int = 0,
q_end: int | None = None,
slot_start: int = 0,
slot_end: int | None = None,
block_size: int = 32,
) -> torch.Tensor:
q = q[q_start:q_end].unsqueeze(2).contiguous()
k = k[q_start:q_end].unsqueeze(2).contiguous()
v = v[q_start:q_end].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[slot_start:slot_end]
return linear_decode_forward_triton(
q, k, v, kv_cache, slope_rate, slot_id, block_size
)
def linear_attention_prefill_and_mix(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
state_indices_tensor: torch.Tensor,
attn_metadata: LinearAttentionMetadata,
slope_rate: torch.Tensor,
block_size: int,
decode_fn: Callable[..., torch.Tensor],
prefix_fn: Callable[..., torch.Tensor],
layer_idx: int | None = None,
) -> torch.Tensor:
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
break
if _prefill_idx >= len(state_indices_tensor):
break
offset = attn_metadata.num_decode_tokens
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = prefix_fn(
qs,
ks,
vs,
slice_layer_cache,
slope_rate,
block_size,
layer_idx=layer_idx,
)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden_decode = decode_fn(
q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
hidden.insert(0, hidden_decode)
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
hidden = torch.concat(hidden, dim=0).contiguous()
return hidden
class MiniMaxText01LinearKernel:
@staticmethod
def jit_linear_forward_prefix(
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def _prefill_and_mix_infer(
self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
):
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
break
if _prefill_idx >= len(state_indices_tensor):
break
offset = attn_metadata.num_decode_tokens
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
qs,
ks,
vs,
slice_layer_cache,
self.tp_slope,
self.BLOCK,
layer_idx=self.layer_idx,
)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden_decode = self._decode_infer(
q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
hidden.insert(0, hidden_decode)
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
hidden = torch.concat(hidden, dim=0).contiguous()
return hidden
return linear_attention_prefill_and_mix(
q=q,
k=k,
v=v,
kv_cache=kv_cache,
state_indices_tensor=state_indices_tensor,
attn_metadata=attn_metadata,
slope_rate=self.tp_slope,
block_size=self.BLOCK,
decode_fn=self._decode_infer,
prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix,
layer_idx=self.layer_idx,
)
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[: attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(
q, k, v, kv_cache, self.tp_slope, slot_id, 32
hidden = linear_attention_decode(
q,
k,
v,
kv_cache,
self.tp_slope,
state_indices_tensor,
q_start=0,
q_end=attn_metadata.num_decode_tokens,
slot_start=0,
slot_end=attn_metadata.num_decodes,
block_size=32,
)
return hidden
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor
num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx
]
q_end = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx + 1
]
query_len = q_end - q_start
context_len = (
attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
- query_len
)
if context_len == 0:
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx
]
kv_cache[block_to_clear, ...] = 0
clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata
)
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if attn_metadata is None:

View File

@@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer):
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
initial_state_idx=block_idx_last_computed_token_p,
cu_chunk_seqlen=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
)
ssm_outputs.append(scan_out_p)

View File

@@ -289,9 +289,6 @@ def get_temporal_copy_spec(
)
get_full_copy_spec = get_temporal_copy_spec
class MambaStateCopyFuncCalculator:
@classmethod
def linear_attention_state_copy_func(cls):

View File

@@ -1159,7 +1159,7 @@ def causal_conv1d_update(
f"ERROR: conv_state_indices should have shape ({batch},*) but got {conv_state_indices.shape}"
)
# assert num_cache_lines >= batch
assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'

View File

@@ -497,6 +497,8 @@ def selective_scan_fn(
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
initial_state_idx=None,
cu_chunk_seqlen=None,
last_chunk_indices=None,
) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
@@ -588,6 +590,8 @@ def selective_scan_fn(
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
cu_chunk_seqlen,
last_chunk_indices,
)
if z is None:

View File

@@ -9,7 +9,6 @@ from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention import MLAAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
@dataclass
class MLAModules:
"""Modules used in MLA."""
@@ -18,7 +17,7 @@ class MLAModules:
kv_b_proj: torch.nn.Module
rotary_emb: torch.nn.Module
o_proj: torch.nn.Module
fused_qkv_a_proj: torch.nn.Module | None
q_a_proj: torch.nn.Module | None
kv_a_proj_with_mqa: torch.nn.Module | None
q_a_layernorm: torch.nn.Module | None
q_b_proj: torch.nn.Module | None
@@ -74,7 +73,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
self.q_a_proj = mla_modules.q_a_proj
self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
self.q_a_layernorm = mla_modules.q_a_layernorm
self.q_b_proj = mla_modules.q_b_proj
@@ -106,7 +105,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
kv_b_proj=self.kv_b_proj,
use_sparse=self.is_sparse,
indexer=self.indexer,
rotary_emb=self.rotary_emb
rotary_emb=self.rotary_emb,
)
self.prefix = prefix
@@ -119,60 +118,47 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
) -> torch.Tensor:
q_c = None
kv_lora = None
if self.q_lora_rank is not None:
assert self.fused_qkv_a_proj is not None, (
"fused_qkv_a_proj is required when q_lora_rank is not None"
)
assert self.q_a_layernorm is not None, (
"q_a_layernorm is required when q_lora_rank is not None"
)
assert self.q_b_proj is not None, (
"q_b_proj is required when q_lora_rank is not None"
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
q = self.q_a_proj(hidden_states)[0]
kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_heads, self.qk_head_dim)
kv_a = self.kv_a_layernorm(kv_a)
else:
assert self.kv_a_proj_with_mqa is not None, (
"kv_a_proj_with_mqa is required when q_lora_rank is None"
)
assert self.q_proj is not None, (
"q_proj is required when q_lora_rank is None"
)
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c)
q = q.view(-1, self.num_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
# k_pe = k_pe.unsqueeze(1)
# if self.rotary_emb is not None:
# q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
# positions, q[..., self.qk_nope_head_dim :], k_pe
# )
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(
hidden_states, q_c, positions, self.indexer_rope_emb
)
q = self.q_proj(hidden_states)[0].view(-1, self.num_heads, self.qk_head_dim)
latent_kpe = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, k_pe = latent_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
kv_a = self.kv_a_layernorm(kv_a)
# NOTE: the attention inputs do not carry positions on their own, so pass them explicitly here
if llama_4_scaling is not None:
q *= llama_4_scaling
self.mla_attn.impl.forward_prepare(positions)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
)
attn_out = self.mla_attn(q, kv_a, k_pe, positions)
return self.o_proj(attn_out)[0]
def forward_opt(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None):
if self.q_lora_rank is not None:
q_latent_kpe = self.q_a_proj(hidden_states)[0]
q, kv_a, k_pe, _ = q_latent_kpe.split([self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim, self.q_a_proj.output_padding_size], dim=1)
q_c = self.q_a_layernorm(q)
q = self.q_b_proj(q_c)[0].view(-1, self.num_heads, self.qk_head_dim)
kv_a = self.kv_a_layernorm(kv_a)
else:
q = self.q_proj(hidden_states)[0].view(-1, self.num_heads, self.qk_head_dim)
latent_kpe = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, k_pe = latent_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
kv_a = self.kv_a_layernorm(kv_a)
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions,
self.rotary_emb)
# NOTE: the attention inputs do not carry positions on their own, so pass them explicitly here
if llama_4_scaling is not None:
q *= llama_4_scaling
attn_out = self.mla_attn(q, kv_a, k_pe, positions)
return self.o_proj(attn_out)[0]
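# forward_opt above assumes q_a_proj produces a fused projection laid out as
# [q_lora_rank | kv_lora_rank | qk_rope_head_dim | output_padding], so a single
# split() recovers the query latent, the KV latent, the rope part and padding.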

View File

@@ -18,6 +18,7 @@ QuantizationMethods = Literal[
"modelopt",
"modelopt_fp4",
"modelopt_mxfp8",
"modelopt_mixed",
"gguf",
"gptq_marlin",
"awq_marlin",
@@ -32,6 +33,7 @@ QuantizationMethods = Literal[
"mxfp4",
"petit_nvfp4",
"cpu_awq",
"w8a16"
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -120,12 +122,18 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .gptq import GPTQConfig
from .gptq_marlin import GPTQMarlinConfig
from .inc import INCConfig
from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config
from .modelopt import (
ModelOptFp8Config,
ModelOptMixedPrecisionConfig,
ModelOptMxFp8Config,
ModelOptNvFp4Config,
)
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config
from .torchao import TorchAOConfig
from .w8a16 import W8a16Config
method_to_config: dict[str, type[QuantizationConfig]] = {
"awq": AWQConfig,
@@ -135,6 +143,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"modelopt": ModelOptFp8Config,
"modelopt_fp4": ModelOptNvFp4Config,
"modelopt_mxfp8": ModelOptMxFp8Config,
"modelopt_mixed": ModelOptMixedPrecisionConfig,
"gguf": GGUFConfig,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
@@ -151,6 +160,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"mxfp4": Mxfp4Config,
"petit_nvfp4": PetitNvFp4Config,
"cpu_awq": CPUAWQConfig,
"w8a16": W8a16Config,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
@@ -9,6 +9,7 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -60,6 +61,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
import ixformer.inference.functions as ixfops
logger = init_logger(__name__)
@@ -197,7 +199,7 @@ class AWQMarlinConfig(QuantizationConfig):
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
# from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
# if is_layer_skipped(
# prefix,
@@ -213,9 +215,10 @@ class AWQMarlinConfig(QuantizationConfig):
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
# layer, prefix
# )
moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
# moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
# moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
# return moe_quant_method
return AWQMarlinMoEMethod(self, layer.moe_config)
return None
@classmethod
@@ -389,13 +392,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
replace_parameter(layer, "qweight", pad_qweight)
replace_parameter(layer, "qzeros", pad_qzeros)
replace_parameter(layer, "scales", pad_scales)
return
# TODO(gyf): the Marlin format is not supported for now.
device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
return
# Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device)
@@ -811,49 +814,33 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
# shared_experts_input carries the shared experts output so that it can be
# fused into the routed-expert result below.
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# return fused_marlin_moe(
# x,
# layer.w13_qweight,
# layer.w2_qweight,
# getattr(layer, "w13_bias", None),
# getattr(layer, "w2_bias", None),
# layer.w13_scales,
# layer.w2_scales,
# topk_weights,
# topk_ids,
# input_global_scale1=getattr(layer, "w13_input_global_scale", None),
# input_global_scale2=getattr(layer, "w2_input_global_scale", None),
# quant_type_id=self.quant_type.id,
# apply_router_weight_on_input=layer.apply_router_weight_on_input,
# global_num_experts=layer.global_num_experts,
# expert_map=layer.expert_map,
# w1_zeros=layer.w13_qzeros,
# w2_zeros=layer.w2_qzeros,
# workspace=layer.workspace,
# input_dtype=self.input_dtype,
# inplace=not self.moe.disable_inplace,
# )
num_tokens, num_experts = router_logits.shape
assert layer.activation.value == "silu", "Only SiLU activation is supported."
use_ep = layer.expert_map is not None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata:
if isinstance(attn_metadata, dict):
only_decode = not use_ep and all(t.num_decodes > 0 and t.num_prefills == 0 for t in attn_metadata.values())
else:
only_decode = not use_ep and attn_metadata.num_decodes > 0 and attn_metadata.num_prefills == 0
else:
only_decode = False
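# Decode-only batches (no prefills, no expert parallelism) take the group-GEMV
# fast path below; everything else uses the group-GEMM path.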
if use_ep:
start_eid = layer.ep_rank * layer.local_num_experts
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
if layer.apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
num_tokens = topk_ids.shape[0]
num_experts = layer.global_num_experts
if use_ep:
hidden_size = x.shape[1]
(
@@ -875,7 +862,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
dtype=x.dtype,
)
else:
expand_tokens = num_tokens * top_k
expand_tokens = num_tokens * layer.top_k
(
src_to_dst,
sorted_token_ids,
@@ -885,7 +872,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
num_experts=num_experts,
)
expert_sizes_cpu = expert_sizes_gpu.cpu()
# expand + reorder
# TODO use kernel
@@ -893,76 +879,130 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
hidden_states=x,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
topk=top_k,
topk=layer.top_k,
src_to_dst=src_to_dst,
)
# w4a16 group gemm 1
# pt_output_1: (expand_tokens, 2n) dtype
pt_output_1 = ixfops.moe_w4a16_group_gemm(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
# w4a16 group gemm 2 + reorder
# pt_output_3: (expand_tokens, k) dtype
if use_ep:
pt_output_3 = torch.empty(
(num_tokens * top_k, hidden_size),
device=x.device,
dtype=x.dtype,
)
ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
if only_decode:
pt_output_1 = ixfops.moe_w4a16_group_gemv(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
output=pt_output_3,
)
reduce_mask = src_to_dst == -1
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, top_k, -1),
topk_weight=topk_weights,
scaling_factor=routed_scaling_factor,
mask=reduce_mask,
)
else:
pt_output_3 = ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
pt_output_3 = ixfops.moe_w4a16_group_gemv(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, top_k, -1),
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=routed_scaling_factor
scaling_factor=layer.routed_scaling_factor,
extra_residual=shared_experts_input,
)
else:
expert_sizes_cpu = expert_sizes_gpu.cpu()
pt_output_1 = ixfops.moe_w4a16_group_gemm(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
# w4a16 group gemm 2 + reorder
# pt_output_3: (expand_tokens, k) dtype
if use_ep:
pt_output_3 = torch.empty(
(num_tokens * layer.top_k, hidden_size),
device=x.device,
dtype=x.dtype,
)
ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
output=pt_output_3,
tokens_per_experts_gpu=expert_sizes_gpu,
)
reduce_mask = src_to_dst == -1
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
mask=reduce_mask,
)
else:
pt_output_3 = ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
extra_residual=shared_experts_input,
)
return final_hidden_states
# return torch.ops.vllm.fused_marlin_moe(
# x,
# layer.w13_qweight,
# layer.w2_qweight,
# layer.w13_scales,
# layer.w2_scales,
# router_logits,
# topk_weights,
# topk_ids,
# w1_zeros=layer.w13_qzeros,
# w2_zeros=layer.w2_qzeros,
# num_bits=self.quant_config.weight_bits,
# )

View File

@@ -18,7 +18,6 @@ from compressed_tensors.quantization import (
)
from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -52,7 +51,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16,
CompressedTensorsW4A8Int8
)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod,
@@ -401,8 +399,8 @@ class CompressedTensorsConfig(QuantizationConfig):
) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
or weight_quant.strategy == QuantizationStrategy.GROUP.value
)
is_tensor = (
weight_strategy
@@ -420,8 +418,8 @@ class CompressedTensorsConfig(QuantizationConfig):
) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
or weight_quant.strategy == QuantizationStrategy.GROUP.value
)
is_token = (
weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
@@ -663,12 +661,6 @@ class CompressedTensorsConfig(QuantizationConfig):
)
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
if envs.VLLM_W8A8_LINEAR_USE_W4A8:
return CompressedTensorsW4A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False,
input_symmetric=input_quant.symmetric,
)
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False,

View File

@@ -8,7 +8,7 @@ from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8, CompressedTensorsW4A8Int8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
@@ -28,5 +28,4 @@ __all__ = [
"CompressedTensorsW4A4Fp4",
"CompressedTensorsW4A8Int",
"CompressedTensorsW4A8Fp8",
"CompressedTensorsW4A8Int8"
]

View File

@@ -25,11 +25,18 @@ logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def __init__(
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool, is_w4a8_linear: bool = False
):
self.strategy = strategy
import vllm.envs as env
if env.VLLM_MIX_QUANTIZATION_TYPE == "TENSOR":
self.strategy = QuantizationStrategy.TENSOR
elif env.VLLM_MIX_QUANTIZATION_TYPE == "CHANNEL":
self.strategy = QuantizationStrategy.CHANNEL
else:
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
self.is_w4a8_linear = is_w4a8_linear
@classmethod
def get_min_capability(cls) -> int:
@@ -53,16 +60,32 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric=self.input_symmetric,
module_name=self.__class__.__name__,
)
remainder = input_size_per_partition % 64
if remainder != 0:
input_size_per_partition_padded = input_size_per_partition + (64 - remainder)
else:
input_size_per_partition_padded = input_size_per_partition
# WEIGHT
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
if self.is_w4a8_linear:
# only "NN" is supported
weight = ModelWeightParameter(data=torch.empty(
input_size_per_partition_padded,
sum(output_partition_sizes) // 2,
dtype=torch.int8),
input_dim=0,
output_dim=1,
weight_loader=weight_loader,
)
else:
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition_padded,
dtype=torch.int8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
@@ -109,104 +132,4 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
class CompressedTensorsW4A8Int8(CompressedTensorsScheme):
def __init__(
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
@classmethod
def get_min_capability(cls) -> int:
# turing and up
return 75
def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
layer.logical_widths = output_partition_sizes
self.kernel = init_int8_linear_kernel(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=self.input_symmetric,
module_name=self.__class__.__name__,
)
# WEIGHT
# weight = ModelWeightParameter(
# data=torch.empty(
# sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
# ),
# input_dim=1,
# output_dim=0,
# weight_loader=weight_loader,
# )
weight = ModelWeightParameter(
data=torch.empty(
input_size_per_partition,
sum(output_partition_sizes) // 2,
dtype=torch.int8
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
input_zero_point = None
input_scale = None
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
)
if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
)
layer.register_parameter("input_zero_point", input_zero_point)
layer.register_parameter("input_scale", input_scale)
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
return self.kernel.apply_weights(layer, x, bias, self.is_w4a8_linear)

View File

@@ -23,17 +23,13 @@ from vllm.model_executor.layers.batch_invariant import (
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported,
MoEActivation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
@@ -50,9 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
@@ -860,14 +853,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_fp8_moe_kernel(
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
@@ -930,29 +919,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# TRTLLM does not use Modular Kernel.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
@@ -983,10 +956,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
@@ -994,50 +963,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == MoEActivation.SILU, (
f"Expected 'silu' activation but got {layer.activation}"
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
def apply(
self,
layer: FusedMoE,
@@ -1046,9 +987,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
assert not self.is_monolithic
return self.moe_mk(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,

View File

@@ -7,6 +7,7 @@ from typing import Any
import gguf
import torch
import torch.nn.functional as F
from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -234,7 +235,7 @@ try:
op_func=_fused_mul_mat_gguf,
fake_impl=_fused_mul_mat_gguf_fake,
)
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
fused_mul_mat_gguf = _fused_mul_mat_gguf
except AttributeError as error:
raise error
@@ -365,7 +366,7 @@ try:
op_func=_fused_moe_gguf,
fake_impl=_fused_moe_gguf_fake,
)
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
fused_moe_gguf = _fused_moe_gguf
except AttributeError as error:
raise error
@@ -410,7 +411,7 @@ try:
op_func=_apply_gguf_embedding,
fake_impl=_apply_gguf_embedding_fake,
)
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
apply_gguf_embedding = _apply_gguf_embedding
except AttributeError as error:
raise error
@@ -451,6 +452,9 @@ class GGUFLinearMethod(LinearMethodBase):
"data_container": [],
"shard_id": [],
"shard_id_map": {},
"params_dtype": params_dtype,
"input_size_per_partition" :input_size_per_partition, # restore shape for qkv and merge
"output_partition_sizes" :output_partition_sizes,
},
)
set_weight_attrs(qweight, extra_weight_attrs)
@@ -664,6 +668,10 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
"""
def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
weight = layer.weight
return F.embedding(x, weight)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
hidden_size = qweight.tensor_shape[1]

View File

@@ -128,7 +128,7 @@ class GPTQConfig(QuantizationConfig):
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
return [torch.bfloat16, torch.half]
@classmethod
# Need to figure it out

View File

@@ -59,9 +59,164 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils.collection_utils import is_list_of
import ixformer.inference.functions as ixfops
logger = init_logger(__name__)
# [B, K // 8, N] -> [B, K, N]
# less memory
def unpack_k_batch_opt(packed_w: torch.Tensor, num_bits: int = 4, chunk_size: int = 2) -> torch.Tensor:
"""
Memory-efficient unpacking for 3D tensor.
Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor,
without broadcasting huge intermediate tensors (avoids OOM).
Args:
packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
num_bits: Number of bits per packed element (e.g., 4 or 2).
chunk_size: How many bit groups to unpack at once (tradeoff between speed and memory).
Returns:
unpacked: torch.int8 tensor of shape [B, K, N].
"""
B, k_packed, N = packed_w.shape
pack_factor = 32 // num_bits
K = k_packed * pack_factor
mask = (1 << num_bits) - 1
# Allocate output tensor once
unpacked = torch.empty((B, K, N), dtype=torch.int8, device=packed_w.device)
# Process bit chunks iteratively to save memory
for i in range(0, pack_factor, chunk_size):
# Precompute shifts for this chunk
shift_vals = num_bits * torch.arange(i, min(i + chunk_size, pack_factor), device=packed_w.device)
# [chunk_size, 1, 1, 1]
shifts = shift_vals.view(-1, 1, 1, 1)
# Compute small chunk only
chunk = ((packed_w.unsqueeze(0) >> shifts) & mask).to(torch.int8)
# chunk: [chunk_size, B, k_packed, N]
# write into output
for j in range(chunk.shape[0]):
unpacked[:, (i + j)::pack_factor, :] = chunk[j]
del chunk # release memory early
return unpacked
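# Example usage (hypothetical shapes, assuming 4-bit AWQ-style packing):
#   packed = torch.randint(0, 2**31 - 1, (8, 512, 1024), dtype=torch.int32)  # [B, K//8, N]
#   unpacked = unpack_k_batch_opt(packed, num_bits=4)                        # [8, 4096, 1024] int8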
# more memory
def unpack_k_batch(packed_w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
"""
Efficient vectorized unpacking for 3D tensor (batch version).
Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor.
Args:
packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
num_bits: Number of bits per packed element (e.g., 4).
Returns:
unpacked: torch.int8 tensor of shape [B, K, N].
"""
B, k_packed, n = packed_w.shape
pack_factor = 32 // num_bits
k = k_packed * pack_factor
mask = (1 << num_bits) - 1
# [pack_factor, 1, 1, 1]
shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1, 1)
# [1, B, k_packed, N]
packed_expanded = packed_w.unsqueeze(0)
# Extract each group of num_bits using bitwise ops
unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8)
# [pack_factor, B, k_packed, N] → [B, K, N]
unpacked = unpacked_groups.permute(1, 2, 0, 3).reshape(B, k, n)
return unpacked
# [B, K, N] -> [B, K, N // 8]
# less memory
def pack_n_batch_opt(x: torch.Tensor, pack_num: int = 8, order_map=None, chunk_size: int = 2) -> torch.Tensor:
"""
Memory-efficient batch packing with correct bit order.
[B, K, N] int4 -> [B, K, N//pack_num] int32.
"""
B, K, N = x.shape
assert N % pack_num == 0, "N must be divisible by pack_num"
cols = N // pack_num
unit = 32 // pack_num
if order_map is None:
order_map = list(range(pack_num))
order_map = torch.tensor(order_map, device=x.device)
shifts = unit * torch.arange(pack_num, device=x.device) # always 0..unit*(pack_num-1)
packed = torch.zeros((B, K, cols), dtype=torch.int32, device=x.device)
x_reshape = x.view(B, K, cols, pack_num) & 0xF
# process in chunks for memory efficiency
for start in range(0, pack_num, chunk_size):
end = min(start + chunk_size, pack_num)
idx_chunk = order_map[start:end]
shift_chunk = shifts[start:end]
vals = torch.gather(x_reshape, 3, idx_chunk.view(1, 1, 1, -1).expand(B, K, cols, -1)).to(torch.int32)
for j in range(vals.shape[-1]):
packed.add_(vals[..., j] << shift_chunk[j])
return packed
# more memory
def pack_n_batch(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor:
"""
Efficient vectorized batch packing: [B, K, N] int4 -> [B, K, N//pack_num] int32.
Args:
x: torch.int32 tensor of shape [B, K, N], each element 0-15 (int4).
pack_num: Number of 4-bit elements per packed int32 (default=8).
order_map: Optional order of elements within each packed int32.
Returns:
torch.int32 tensor of shape [B, K, N//pack_num].
"""
B, K, N = x.shape
assert N % pack_num == 0, "N must be divisible by pack_num"
cols = N // pack_num
if order_map is None:
order_map = list(range(pack_num))
order_map = torch.tensor(order_map, device=x.device)
unit = 32 // pack_num # number of bits per element
# reshape to [B, K, cols, pack_num]
pack_num_int = int(pack_num)
x_reshape = x.view(B, K, cols, pack_num_int)
# reorder according to order_map
x_reorder = torch.gather(
x_reshape, 3, order_map.view(1, 1, 1, -1).expand(B, K, cols, -1)
)
# mask low 4 bits
x_reorder = x_reorder & 0xF
# bit shifts [pack_num] -> [1,1,1,pack_num] broadcastable
shifts = (unit * torch.arange(pack_num_int, device=x.device)).view(1, 1, 1, -1)
# shift and sum along last dimension to combine bits
packed = (x_reorder << shifts).sum(dim=-1).to(torch.int32)
return packed
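# Example round trip (hypothetical shapes): unpack along K, then repack along N
# with the interleaved order used for the ixformer kernels below:
#   w_packed = torch.randint(0, 2**31 - 1, (8, 512, 1024), dtype=torch.int32)   # [B, K//8, N]
#   w_int4 = unpack_k_batch(w_packed, num_bits=4)                               # [B, 4096, 1024]
#   w_repacked = pack_n_batch(w_int4, 8, order_map=[0, 2, 4, 6, 1, 3, 5, 7])    # [B, 4096, 128]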
def get_moe_quant_method(
config: "GPTQMarlinConfig",
@@ -495,8 +650,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
elif self.quant_config.quant_type.size_bits == 8:
self.quant_type = scalar_types.uint8b128
# elif self.quant_config.quant_type.size_bits == 8:
# self.quant_type = scalar_types.uint8b128
else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None
@@ -594,7 +749,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_experts,
scales_size13,
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
dtype=params_dtype,
dtype=torch.int32,
),
requires_grad=False,
)
@@ -606,7 +761,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
dtype=params_dtype,
dtype=torch.int32,
),
requires_grad=False,
)
@@ -656,7 +811,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace_new(device, 4)
# layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
@@ -673,119 +828,111 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_scales.data = layer.w2_scales.data * 512
# Process act_order
if self.quant_config.desc_act:
# if self.quant_config.desc_act:
# Get sorting based on g_idx
num_experts = layer.w13_g_idx.shape[0]
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
torch.int32
)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
torch.int32
)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# num_experts = layer.w13_g_idx.shape[0]
# w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
# w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
# w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
# w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
# for e in range(num_experts):
# w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
# torch.int32
# )
# w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
# torch.int32
# )
# w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
# w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
# replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
# replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
# replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
# replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
# else:
# # Reset g_idx related tensors
# num_experts = layer.w13_g_idx.shape[0]
# device = layer.w13_g_idx.device
# layer.w13_g_idx = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# layer.w2_g_idx = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# layer.w13_g_idx_sort_indices = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# layer.w2_g_idx_sort_indices = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# # Repack weights
# marlin_w13_qweight = ops.gptq_marlin_moe_repack(
# layer.w13_qweight,
# layer.w13_g_idx_sort_indices,
# layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
# layer.w13_qweight.shape[2],
# self.quant_config.quant_type.size_bits,
# )
# replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
# marlin_w2_qweight = ops.gptq_marlin_moe_repack(
# layer.w2_qweight,
# layer.w2_g_idx_sort_indices,
# layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
# layer.w2_qweight.shape[2],
# self.quant_config.quant_type.size_bits,
# )
# replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# # Repack scales
# marlin_w13_scales = marlin_moe_permute_scales(
# s=layer.w13_scales,
# size_k=layer.intermediate_size_per_partition,
# size_n=layer.w13_scales.shape[2],
# group_size=self.quant_config.group_size,
# )
# replace_parameter(layer, "w13_scales", marlin_w13_scales)
# marlin_w2_scales = marlin_moe_permute_scales(
# s=layer.w2_scales,
# size_k=layer.w2_scales.shape[1]
# * (
# self.quant_config.group_size
# if self.quant_config.group_size != -1
# else self.quant_config.pack_factor
# ),
# size_n=layer.w2_scales.shape[2],
# group_size=self.quant_config.group_size,
# )
# replace_parameter(layer, "w2_scales", marlin_w2_scales)
# The modular kernel expects w13_weight and w2_weight,
# but GPTQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer.w13_weight = layer.w13_qweight
# Alias for modular kernel
layer.w2_weight = layer.w2_qweight
# if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
# layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1]
* (
self.quant_config.group_size
if self.quant_config.group_size != -1
else self.quant_config.pack_factor
),
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
# if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
# layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
if self.quant_config.desc_act:
raise NotImplementedError(
"GPTQMarlinMoEMethod now not support desc_act. please fix it")
w13_qweight_unpacked = unpack_k_batch(layer.w13_qweight)
w13_qweight_repacked = pack_n_batch(w13_qweight_unpacked, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
replace_parameter(layer, "w13_qweight", w13_qweight_repacked)
# See quantize_weights in vllm/model_executor/layers/quantization/utils/quant_utils.py:
# if quant_type.has_bias():
# w_q += quant_type.bias
# use quant_type.bias as the zero point (zp), which ixformer supports
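# e.g. scalar_types.uint4b8 has bias 8, so the zero-point tensors built below are
# filled with 8 and then packed into the same interleaved int32 layout as the weights.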
w13_zp = torch.full_like(layer.w13_scales, self.quant_type.bias, dtype=torch.int32)
w13_zp_pack = pack_n_batch(w13_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
replace_parameter(layer, "w13_qzeros", w13_zp_pack)
w2_qweight_unpacked = unpack_k_batch(layer.w2_qweight)
w2_qweight_repacked = pack_n_batch(w2_qweight_unpacked, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
replace_parameter(layer, "w2_qweight", w2_qweight_repacked)
w2_zp = torch.full_like(layer.w2_scales, self.quant_type.bias, dtype=torch.int32)
w2_zp_pack = pack_n_batch(w2_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
replace_parameter(layer, "w2_qzeros", w2_zp_pack)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -900,30 +1047,165 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
# shared_experts_output is passed in as shared_experts_input so it can be fused into the final reduction
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
getattr(layer, "w13_bias", None),
getattr(layer, "w2_bias", None),
layer.w13_scales,
layer.w2_scales,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
assert layer.activation.value == "silu", "Only SiLU activation is supported."
use_ep = layer.expert_map is not None
if use_ep:
start_eid = layer.ep_rank * layer.local_num_experts
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
if layer.apply_router_weight_on_input:
raise NotImplementedError(
"GPTQMarlinMoEMethod Apply router weight on input is not supported for"
"fused Marlin MoE method.")
if (hasattr(layer, "w13_bias") and layer.w13_bias is not None) or (hasattr(layer, "w2_bias") and layer.w2_bias is not None):
raise NotImplementedError(
"GPTQMarlinMoEMethod moe_w4a16_group_gemm not supported bias, please fix this")
num_tokens = topk_ids.shape[0]
num_experts = layer.global_num_experts
if use_ep:
hidden_size = x.shape[1]
(
src_to_dst,
sorted_token_ids,
expert_sizes_gpu,
expert_sizes_cpu,
expand_tokens,
) = ixfops.moe_compute_token_index_ep(
topk_ids=topk_ids,
num_experts=num_experts,
start_expert_id=start_eid,
end_expert_id=end_eid,
)
if expert_sizes_cpu.sum() == 0:
return torch.zeros(
(num_tokens, hidden_size),
device=x.device,
dtype=x.dtype,
)
else:
expand_tokens = num_tokens * layer.top_k
(
src_to_dst,
sorted_token_ids,
expert_sizes_gpu,
expert_sizes_cpu,
) = ixfops.moe_compute_token_index(
topk_ids=topk_ids,
num_experts=num_experts,
)
expert_sizes_cpu = expert_sizes_gpu.cpu()
# expand + reorder
# TODO use kernel
expand_hidden_states = ixfops.moe_expand_input(
hidden_states=x,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
topk=layer.top_k,
src_to_dst=src_to_dst,
)
# w4a16 group gemm 1
# pt_output_1: (expand_tokens, 2n) dtype
pt_output_1 = ixfops.moe_w4a16_group_gemm(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
# w4a16 group gemm 2 + reorder
# pt_output_3: (expand_tokens, k) dtype
if use_ep:
pt_output_3 = torch.empty(
(num_tokens * layer.top_k, hidden_size),
device=x.device,
dtype=x.dtype,
)
ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
output=pt_output_3,
tokens_per_experts_gpu=expert_sizes_gpu,
)
reduce_mask = src_to_dst == -1
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
mask=reduce_mask,
)
else:
pt_output_3 = ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
extra_residual=shared_experts_input,
)
return final_hidden_states
# return torch.ops.vllm.fused_marlin_moe(
# x,
# layer.w13_qweight,
# layer.w2_qweight,
# getattr(layer, "w13_bias", None),
# getattr(layer, "w2_bias", None),
# layer.w13_scales,
# layer.w2_scales,
# router_logits,
# topk_weights,
# topk_ids,
# quant_type_id=self.quant_type.id,
# apply_router_weight_on_input=apply_router_weight_on_input,
# global_num_experts=global_num_experts,
# expert_map=expert_map,
# g_idx1=layer.w13_g_idx,
# g_idx2=layer.w2_g_idx,
# sort_indices1=layer.w13_g_idx_sort_indices,
# sort_indices2=layer.w2_g_idx_sort_indices,
# workspace=layer.workspace,
# is_k_full=self.is_k_full)

View File

@@ -12,8 +12,7 @@ from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
@@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
process_fp8_input_tensor_strategy_moe,
@@ -114,6 +104,8 @@ QUANT_ALGOS = [
"NVFP4",
# MXFP8
"MXFP8",
# MIXED_PRECISION,
"MIXED_PRECISION",
]
KV_CACHE_QUANT_ALGOS = ["FP8"]
@@ -181,7 +173,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
# handle kv-cache first so we can focus only on weight quantization thereafter
if isinstance(layer, Attention):
if isinstance(layer, (Attention, MLAAttention)):
return self.KVCacheMethodCls(self)
# handle exclusion
@@ -235,6 +227,26 @@ class ModelOptQuantConfigBase(QuantizationConfig):
self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)
@staticmethod
def _extract_modelopt_quant_algo(
hf_quant_cfg: dict[str, Any] | None,
) -> str | None:
"""Extract upper-cased quant_algo from a modelopt config.
Returns the quant_algo string (upper-cased), or None if the config
is not a modelopt config.
"""
if hf_quant_cfg is None:
return None
if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
return None
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
return str(quant_config.get("quant_algo", "")).upper()
return None
return str(hf_quant_cfg.get("quant_algo", "")).upper()
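# Both layouts are accepted, e.g. (hypothetical configs):
#   {"quant_method": "modelopt", "quantization": {"quant_algo": "NVFP4"}}
#   {"quant_method": "modelopt", "quant_algo": "FP8"}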
@staticmethod
def get_config_filenames() -> list[str]:
return ["hf_quant_config.json"]
@@ -272,10 +284,20 @@ class ModelOptQuantConfigBase(QuantizationConfig):
# "exclude_modules" is the key in the legacy hf_quant_config.json
exclude_modules = quant_config.get("exclude_modules", [])
else:
# Compressed-tensors style format:
# Compressed-tensors style format (config.json quantization_config):
# {"quant_algo": "...", "quant_method": "modelopt"}
quant_method = config.get("quant_algo")
kv_cache_quant_method = config.get("kv_cache_quant_algo")
# "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string).
kv_cache_scheme = config.get("kv_cache_scheme")
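# e.g. a hypothetical FP8 entry: "kv_cache_scheme": {"type": "float", "num_bits": 8}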
if isinstance(kv_cache_scheme, dict) and (
kv_cache_scheme.get("type") == "float"
and kv_cache_scheme.get("num_bits") == 8
):
kv_cache_quant_method = "FP8"
else:
kv_cache_quant_method = None
# "ignore" is the key in config.json
exclude_modules = config.get("ignore", [])
group_size_raw = config.get("group_size")
@@ -379,32 +401,9 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
"""Detect if this ModelOpt config should be used based on
quantization config."""
if hf_quant_cfg is None:
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = str(quant_config.get("quant_algo", ""))
if quant_algo.upper() == "FP8":
return "modelopt"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
if quant_algo.upper() == "FP8":
return "modelopt"
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "FP8":
return "modelopt"
return None
@classmethod
@@ -737,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
@@ -745,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
@@ -862,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
assert self.experts_cls is not None
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13 = layer.w13_weight
@@ -904,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
a1_scale = layer.w13_input_scale
@@ -920,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
a2_scale=a2_scale,
)
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
@@ -931,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
)
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert layer.activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {layer.activation} found instead."
)
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
@@ -964,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
assert layer.activation in (
MoEActivation.SILU,
MoEActivation.RELU2_NO_MUL,
), (
"Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {layer.activation}"
)
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
@@ -1031,32 +1003,9 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
"""Detect if this ModelOpt FP4 config should be used based on
quantization config."""
if hf_quant_cfg is None:
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = quant_config.get("quant_algo", "")
if "NVFP4" in quant_algo:
return "modelopt_fp4"
else:
# Check for compressed-tensors style config with specific
# quant_algo field
quant_algo = hf_quant_cfg.get("quant_algo", "")
if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
return "modelopt_fp4"
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and ("NVFP4" in algo or "FP4" in algo):
return "modelopt_fp4"
return None
@classmethod
@@ -1249,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
@@ -1434,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
replace_parameter(layer, "w2_input_scale", a2_scale)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
# Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
@property
def do_post_quant_allgather(self):
return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
def prepare_dp_allgather_tensor(
self,
layer: FusedMoE,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
raise RuntimeError(
"prepare_dp_allgather_tensor is only supported for "
"FlashInfer TRTLLM NVFP4 MoE backend."
)
import flashinfer
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
hidden_states,
layer.a1_gscale,
is_sf_swizzled_layout=False,
assert self.experts_cls is not None
self.moe_kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
extra_tensors: list[torch.Tensor] = [hidden_states_sf]
return hidden_states_fp4, extra_tensors
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
@@ -1493,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not self.moe.moe_parallel_config.enable_eplb
)
def apply_monolithic(
self,
layer: FusedMoE,
@@ -1507,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=layer.top_k,
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
@@ -1534,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
# EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
assert layer.enable_eplb
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
)
else:
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
@@ -1619,31 +1502,9 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
"""Detect if this ModelOpt MXFP8 config should be used based on
quantization config."""
if hf_quant_cfg is None:
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = str(quant_config.get("quant_algo", "")).upper()
if "MXFP8" in quant_algo:
return "modelopt_mxfp8"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper()
if "MXFP8" in quant_algo:
return "modelopt_mxfp8"
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and "MXFP8" in algo:
return "modelopt_mxfp8"
return None
@classmethod
@@ -1841,3 +1702,188 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
# Register the method classes for ModelOptMxFp8Config
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
"""Config class for ModelOpt MIXED_PRECISION.
Supports checkpoints where different layers use different quantization
algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts).
The per-layer algorithm is specified in the ``quantized_layers`` dict
inside ``config.json``'s ``quantization_config`` (preferred) or the
legacy ``hf_quant_config.json``.
"""
def __init__(
self,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
quantized_layers: dict[str, dict[str, Any]],
fp8_config: ModelOptFp8Config,
nvfp4_config: ModelOptNvFp4Config,
) -> None:
super().__init__(exclude_modules)
self.kv_cache_quant_method = kv_cache_quant_method
self.quantized_layers = quantized_layers
self.fp8_config = fp8_config
self.nvfp4_config = nvfp4_config
def get_name(self) -> QuantizationMethods:
return "modelopt_mixed"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 89
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "MIXED_PRECISION":
return "modelopt_mixed"
return None
@classmethod
def _from_config(
cls,
*,
quant_method: str,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
original_config: dict[str, Any],
group_size: int | None,
**kwargs: Any,
) -> "ModelOptMixedPrecisionConfig":
if "quantization" in original_config:
quantized_layers = original_config["quantization"].get(
"quantized_layers", {}
)
else:
quantized_layers = original_config.get("quantized_layers", {})
if not quantized_layers:
raise ValueError(
"MIXED_PRECISION quant_algo requires a non-empty "
"'quantized_layers' mapping in the quantization config."
)
# Determine group_size from the first NVFP4 entry if not provided.
if group_size is None:
for layer_info in quantized_layers.values():
if layer_info.get("quant_algo", "").upper() == "NVFP4":
group_size = layer_info.get("group_size", 16)
break
if group_size is None:
group_size = 16
fp8_config = ModelOptFp8Config(
quant_method="FP8",
is_checkpoint_fp8_serialized=True,
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=[],
)
nvfp4_config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=kv_cache_quant_method,
exclude_modules=[],
group_size=group_size,
)
return cls(
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=exclude_modules,
quantized_layers=quantized_layers,
fp8_config=fp8_config,
nvfp4_config=nvfp4_config,
)
def _resolve_quant_algo(self, prefix: str) -> str | None:
"""Look up the quant_algo for a vLLM-side layer prefix.
Tries three strategies in order:
1. Direct lookup in ``quantized_layers``.
2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``).
3. Prefix-based lookup for FusedMoE (any child key starts with
``prefix + "."``).
Returns the upper-cased quant_algo string, or *None* if the prefix
is not found.
"""
# 1. Direct lookup
if prefix in self.quantized_layers:
return self.quantized_layers[prefix]["quant_algo"].upper()
# 2. Packed / fused layer lookup
proj_name = prefix.rsplit(".", 1)[-1]
if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
algos: set[str] = set()
base = prefix.rsplit(".", 1)[0]
for shard_name in self.packed_modules_mapping[proj_name]:
shard_prefix = f"{base}.{shard_name}"
if shard_prefix in self.quantized_layers:
algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
if len(algos) == 1:
return algos.pop()
if len(algos) > 1:
raise ValueError(
f"Mixed quant_algo within fused layer {prefix}: "
f"{algos}. All shards must use the same quantization."
)
# 3. Prefix-based lookup (for FusedMoE / parent modules)
prefix_dot = prefix + "."
for key, info in self.quantized_layers.items():
if key.startswith(prefix_dot):
return info["quant_algo"].upper()
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
"""Return quantize-method based on layer."""
# KV-cache quantization
if isinstance(layer, Attention):
if self.kv_cache_quant_method:
return ModelOptFp8KVCacheMethod(self)
return None
# Excluded layers
if self.is_layer_excluded(prefix):
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
return None
quant_algo = self._resolve_quant_algo(prefix)
if isinstance(layer, LinearBase):
if quant_algo == "FP8":
return ModelOptFp8LinearMethod(self.fp8_config)
if quant_algo == "NVFP4":
return ModelOptNvFp4LinearMethod(self.nvfp4_config)
# Layer not in quantized_layers — leave unquantized
return UnquantizedLinearMethod()
if isinstance(layer, FusedMoE):
if quant_algo == "FP8":
return ModelOptFp8MoEMethod(
quant_config=self.fp8_config,
moe_config=layer.moe_config,
)
if quant_algo == "NVFP4":
return ModelOptNvFp4FusedMoE(
quant_config=self.nvfp4_config,
moe_config=layer.moe_config,
)
return None
return None
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
super().apply_vllm_mapper(hf_to_vllm_mapper)
if self.quantized_layers:
self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)
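# For reference, a minimal sketch of the quantization_config shape this class
# expects; the layer names and values below are illustrative, not taken from a
# real checkpoint. _from_config reads the quantized_layers mapping, infers
# group_size from the first NVFP4 entry, and builds the FP8/NVFP4 sub-configs;
# _resolve_quant_algo then maps vLLM prefixes (fused shards such as qkv_proj,
# FusedMoE parent prefixes) back onto these entries.
example_hf_quant_cfg = {
    "quant_method": "modelopt",
    "quantization": {
        "quant_algo": "MIXED_PRECISION",
        "kv_cache_quant_algo": "FP8",
        "quantized_layers": {
            # Dense attention shards stay FP8; the fused qkv_proj prefix
            # resolves to FP8 via packed_modules_mapping (strategy 2 above).
            "model.layers.0.self_attn.q_proj": {"quant_algo": "FP8"},
            "model.layers.0.self_attn.k_proj": {"quant_algo": "FP8"},
            "model.layers.0.self_attn.v_proj": {"quant_algo": "FP8"},
            # Expert weights use NVFP4; the FusedMoE parent prefix
            # "model.layers.0.mlp.experts" resolves via the prefix lookup
            # (strategy 3 above).
            "model.layers.0.mlp.experts.0.down_proj": {
                "quant_algo": "NVFP4",
                "group_size": 16,
            },
        },
    },
}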

View File

@@ -6,6 +6,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
@@ -77,6 +78,8 @@ class Mxfp4Backend(Enum):
# Triton Backend
TRITON = 6
CK = 7
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
"""
@@ -167,9 +170,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
elif current_platform.is_xpu():
logger.info_once("Using xpu backend on XPU")
return Mxfp4Backend.MARLIN
elif current_platform.is_rocm() and has_triton_kernels():
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
elif current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx950
if rocm_aiter_ops.is_enabled() and on_gfx950():
logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
return Mxfp4Backend.CK
elif has_triton_kernels():
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
return Mxfp4Backend.NONE
@@ -257,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
self.moe_mk: mk.FusedMoEModularKernel | None = None
self.moe_kernel: mk.FusedMoEKernel | None = None
def create_weights(
self,
@@ -338,6 +347,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
self.intermediate_pad = (
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
)
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
@@ -427,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
MarlinExperts(
self.moe,
@@ -776,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
FlashInferExperts(
moe_config=self.moe,
@@ -784,6 +797,66 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
),
shared_experts=None,
)
elif self.mxfp4_backend == Mxfp4Backend.CK:
if layer.w13_bias is not None:
layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
if layer.w2_bias is not None:
layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)
e, n, k = layer.w13_weight.shape
layer.w13_weight.view(torch.uint8).copy_(
layer.w13_weight.data.view(torch.uint8)
.view(e, n // 2, 2, k)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, k)
)
layer.w13_weight_scale.data = (
layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, -1)
)
layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)
layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
layer.w13_weight, 16, True
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
self.num_experts,
True,
)
layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
layer.w2_weight, 16, False
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
self.num_experts,
False,
)
layer.w13_bias.data = (
layer.w13_bias.data.view(-1, n // 2, 2)
.permute(0, 2, 1)
.contiguous()
.view(-1, n)
)
layer.w13_weight_scale = torch.nn.Parameter(
shuffled_w13_scale, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
shuffled_w2_scale, requires_grad=False
)
# replace_parameter(layer, "w13_bias", w13_bias)
# replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
# replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
# replace_parameter(layer, "w13_weight", w13_weight)
# replace_parameter(layer, "w2_weight", w2_weight)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
@@ -792,18 +865,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
# (stored in self.fused_experts) to determine if the MoE has a
# batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
is_batched_moe = self.moe.use_deepep_ll_kernels
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
@@ -817,13 +888,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
self.w13_weight = w13_weight
self.w2_weight = w2_weight
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = w13_weight
layer.w2_weight = w2_weight
else:
raise ValueError(
f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
@@ -862,6 +933,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
elif self.mxfp4_backend in [
Mxfp4Backend.SM100_FI_MXFP4_BF16,
Mxfp4Backend.SM90_FI_MXFP4_BF16,
Mxfp4Backend.CK,
]:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
@@ -882,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
@@ -929,10 +1001,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
@property
def is_monolithic(self) -> bool:
if self.moe.is_lora_enabled:
return False
return (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.TRITON
or self.mxfp4_backend == Mxfp4Backend.CK
)
def apply(
@@ -968,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
or self.mxfp4_backend == Mxfp4Backend.MARLIN
)
assert self.moe_mk is not None
return self.moe_mk(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -1054,6 +1129,27 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tune_max_num_tokens=max(self.max_capture_size, 1),
)[0]
return trtllm_gen_output
elif self.mxfp4_backend == Mxfp4Backend.CK:
topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
x, router_logits, layer.top_k, True
)
output = rocm_aiter_ops.fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
doweight_stage1=False,
hidden_pad=self.hidden_pad // 128 * 128,
intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
bias1=layer.w13_bias,
bias2=layer.w2_bias,
)
return output
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
@@ -1162,7 +1258,7 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
topk_weights=routing_weights,
topk_ids=selected_experts,
n_experts_per_token=layer.top_k,
activation=layer.activation,
activation=layer.activation.value,
num_experts=layer.local_num_experts,
is_mxfp4=True,
)

View File

@@ -7,7 +7,6 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
@@ -26,10 +25,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
class PTPCFp8Config(Fp8Config):
"""Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""

View File

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig):
self.kv_cache_group = kv_cache_group
self.kv_cache_config = kv_cache_config
self.pack_method = pack_method
self.dynamic_mxfp4_quant = False
def maybe_update_config(self, model_name: str, revision: str | None = None):
self.hf_config = get_config(
model=model_name,
trust_remote_code=False, # or get from model_config if available
revision=revision,
config_format="auto",
)
quant_config = getattr(self.hf_config, "quantization_config", None)
if quant_config is not None:
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
model_type = self.hf_config.model_type
if quant_dtype == "fp4" and model_type == "deepseek_v3":
self.dynamic_mxfp4_quant = True
def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self)
@@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig):
if should_ignore_layer(
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
):
return UnquantizedLinearMethod()
if (
"self_attn" not in prefix # only quantize attention projections
or not getattr(self, "dynamic_mxfp4_quant", False)
or not isinstance(layer, LinearBase) # Ignore other methods
):
return UnquantizedLinearMethod()
scheme = self.get_scheme(
layer=layer,
layer_name=prefix,
dynamic_mxfp4_quant=True,
)
layer.scheme = scheme
return QuarkLinearMethod(self)
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme
@@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig):
)
return global_quant_config
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
def _get_scheme_from_config(
self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False
) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
@@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig):
input_symmetric=input_config.get("symmetric"),
)
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX(weight_config, input_config)
return QuarkOCP_MX(
weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
)
raise NotImplementedError(
"No quark compatible scheme was found. "
@@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig):
f"Input config: {input_config}"
)
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
def get_scheme(
self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False
) -> "QuarkScheme":
layer_quant_config = self._find_matched_config(layer_name, layer)
# Find the quant_scheme
scheme = self._get_scheme_from_config(layer_quant_config)
scheme = self._get_scheme_from_config(
layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())

View File

@@ -5,8 +5,8 @@ from typing import Any
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
@@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
__all__ = [
"QuarkMoEMethod",
"QuarkOCP_MX_MoEMethod",
"QuarkOCP_MX_MoEMethod_OSS",
]
class QuarkMoEMethod(FusedMoEMethodBase):
@@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase):
"output_tensors and bias "
"quantized are not supported"
)
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w4a8(weight_config, input_config):
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
emulate = not current_platform.supports_mx() or not (
rocm_aiter_ops.is_fused_moe_enabled()
)
if (
input_config.get("dtype") == "fp8_e4m3"
and not input_config.get("is_dynamic")
and not emulate
):
return QuarkOCP_MX_MoEMethod_OSS(
weight_config, input_config, module.moe_config
)
else:
return QuarkOCP_MX_MoEMethod(
weight_config, input_config, module.moe_config
)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
@@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
get_current_vllm_config().model_config.hf_config, "model_type", None
)
self._emulate = (
self.emulate = (
not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
self.emulate = True if self.model_type == "gpt_oss" else self._emulate
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
@@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
params_dtype = torch.uint8
self.intermediate_size_per_partition = intermediate_size_per_partition
if self.model_type == "gpt_oss":
if current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
@@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
else:
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
self.unpadded_hidden_size = extra_weight_attrs.get(
"unpadded_hidden_size", hidden_size
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
@@ -991,30 +1015,20 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate:
if (
self.model_type == "gpt_oss"
and self.mxfp4_backend == Mxfp4Backend.TRITON
):
raise NotImplementedError(
"Triton kernel implemented fused MoE for GPT_OSS model "
"in Quark(MoE) format is not integrated or provided yet."
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
else:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -1031,3 +1045,133 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
def __init__(
self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(weight_config, input_config, moe)
def process_weights_after_loading(self, layer):
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_bias = layer.w13_bias.to(torch.float32)
w2_bias = layer.w2_bias.to(torch.float32)
layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
# FIXME: num_warps needs to be adjusted based on the batch size
# (only applies in batched mode)
if self.moe.use_ep:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps
)
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
# delete the original weights to save memory on a single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
if self.static_input_scales:
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max().to(torch.float32), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max().to(torch.float32), requires_grad=False
)
from triton_kernels.numerics import InFlexData
lhs_data13 = InFlexData(scale=layer.w13_input_scale)
lhs_data2 = InFlexData(scale=layer.w2_input_scale)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale,
flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13),
)
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale,
flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2),
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return mxfp4_w4a8_moe_quant_config(
w1_scale=self.w13_precision_config,
w2_scale=self.w2_precision_config,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
block_shape=None,
)
@property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet."
)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
)
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w2=self.w2_weight_triton_tensor,
gating_output=router_logits,
topk=layer.top_k,
renormalize=layer.renormalize,
global_num_experts=layer.global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
unpadded_N_w1=self.intermediate_size_per_partition * 2,
unpadded_K_w1=self.unpadded_hidden_size,
unpadded_N_w2=self.unpadded_hidden_size,
unpadded_K_w2=self.intermediate_size_per_partition,
)

View File

@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
)
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PackedvLLMParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from .quark_scheme import QuarkScheme
@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError):
class QuarkOCP_MX(QuarkScheme):
def __init__(
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
self,
weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any],
dynamic_mxfp4_quant: bool = False,
):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme):
layer.weight_scale.data, requires_grad=False
)
else:
if self.rocm_use_aiter_fp4_asm_gemm:
if self.dynamic_mxfp4_quant:
w_q, w_s = dynamic_mxfp4_quant(layer.weight)
layer.weight_scale = torch.nn.Parameter(
w_s.T.contiguous(), requires_grad=False
)
layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
elif self.rocm_use_aiter_fp4_asm_gemm:
# shuffle weight scale
weight_scale_shuffle = layer.weight_scale.data
sm, sn = weight_scale_shuffle.shape
@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme):
weight_loader: Callable,
**kwargs,
):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
if self.dynamic_mxfp4_quant:
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
# WEIGHT
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.packed_factor,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, kwargs)
else:
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# WEIGHT
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.packed_factor,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def apply_weights(
self,

View File

@@ -6,28 +6,18 @@ from typing import TYPE_CHECKING
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
align_fp4_moe_weights_for_fi,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_cutlass_fused_moe,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
@@ -42,92 +32,15 @@ __all__ = [
"reorder_w1w3_to_w3w1",
]
#
# Methods used by the oracle for kernel selection.
#
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
def _supports_no_act_and_mul() -> bool:
"""Supports non-gated MoE."""
return True
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Nvfp4 quantization."""
SUPPORTED_W_A = [
(kNvfp4Static, kNvfp4Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
def _supports_routing_method(
routing_method: RoutingMethodType,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.Llama4,
]
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""
TRTLLM is a monolithic kernel that requires dispatch_router_logits() for
the naive dispatch/combine path. DeepEP HT only implements dispatch() for
the modular kernel path, so TRTLLM is incompatible with DeepEP HT.
"""
return not moe_parallel_config.use_deepep_ht_kernels
def is_supported_config_trtllm(
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason(f"current device {current_platform.device_name}")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_quant_scheme(weight_key, activation_key):
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
elif not _supports_routing_method(moe_config.routing_method):
return False, _make_reason(f"routing method {moe_config.routing_method}")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason(f"activation format {activation_format}")
elif moe_config.hidden_dim % 512 != 0:
return False, _make_reason(
f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}"
)
return True, None
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.has_device_capability(100)
)
def reorder_w1w3_to_w3w1(
@@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe(
)
def flashinfer_trtllm_fp4_moe(
layer: torch.nn.Module,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor,
top_k: int,
activation: MoEActivation,
global_num_experts: int,
num_expert_group: int | None,
topk_group: int | None,
custom_routing_function: object | None,
e_score_correction_bias: torch.Tensor | None,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
router_logits: Router logits for expert selection
top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks
num_expert_group: Number of expert groups (for grouped routing)
topk_group: Top-k within each group
custom_routing_function: Custom routing function (e.g., Llama4)
e_score_correction_bias: Optional routing bias correction
Returns:
Output tensor from the MoE layer
"""
import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {activation} found instead."
)
# Quantize input to FP4
if isinstance(x, tuple):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states is already quantized to FP4
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Determine routing method type
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
routing_method_type = layer.routing_method_type
if use_llama4_routing:
routing_method_type = flashinfer.RoutingMethodType.Llama4
# Cast to Fp32 (required by kernel).
router_logits = (
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
)
# Determine activation type
activation_type = activation_to_flashinfer_int(layer.activation)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).reshape(*hidden_states_fp4.shape[:-1], -1),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group if num_expert_group is not None else 0,
topk_group=topk_group if topk_group is not None else 0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
routing_method_type=routing_method_type,
do_finalize=True,
activation_type=activation_type,
)[0]
return out
def flashinfer_trtllm_fp4_routed_moe(
layer: torch.nn.Module,
x: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
activation: MoEActivation,
global_num_experts: int,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
input top k expert indices and scores rather than computing
top k expert indices from scores.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
topk_ids: Ids of selected experts
top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks
Returns:
Output tensor from the MoE layer
"""
import flashinfer
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
assert activation == MoEActivation.SILU, (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
f"{activation} found instead."
)
# Pack top k ids and expert weights into a single int32 tensor, as
# required by TRT-LLM
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
if isinstance(x, tuple):
# hidden_states is already quantized to FP4
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
topk_ids=packed_tensor,
routing_bias=None,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).reshape(*hidden_states_fp4.shape[:-1], -1),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=0,
topk_group=0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1,
do_finalize=True,
)[0]
return out
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
backend: "NvFp4MoeBackend",
layer: "FusedMoE",
@@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
)
)
layer.intermediate_size_per_partition = padded_intermediate
layer.moe_config.intermediate_size_per_partition = padded_intermediate
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
w13,

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import TYPE_CHECKING
import torch
@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
if TYPE_CHECKING:
from flashinfer.fused_moe.core import ActivationType
logger = init_logger(__name__)
@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum):
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
return activation_to_flashinfer_type(activation).value
def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
from flashinfer.fused_moe.core import ActivationType
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int:
MoEActivation.GELU: ActivationType.Geglu,
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
}
return ACTIVATION_TO_FI_ACTIVATION[activation].value
return ACTIVATION_TO_FI_ACTIVATION[activation]
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
)
def register_scales_for_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> None:
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
layer.w2_input_scale_inv = 1.0 / w2_input_scale
layer.output1_scales_gate_scalar = g1_alphas
if layer.activation.is_gated:
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
else:
layer.output1_scales_scalar = (
torch.ones_like(g1_alphas) * layer.w2_input_scale_inv
)
layer.output2_scales_scalar = g2_alphas
def apply_fi_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
global_num_experts: int,
apply_router_weight_on_input: bool,
) -> torch.Tensor:
from flashinfer.fused_moe import RoutingMethodType
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
from vllm.model_executor.models.llama4 import Llama4MoE
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert (
hasattr(layer, "output1_scales_scalar")
and hasattr(layer, "output1_scales_gate_scalar")
and hasattr(layer, "output2_scales_scalar")
)
if layer.routing_method_type == RoutingMethodType.Llama4:
assert (
not layer.renormalize
and layer.custom_routing_function == Llama4MoE.custom_routing_function
), (
"FusedMoE flashinfer kernels with Llama4 routing method are only "
"supported for Llama4"
)
else:
assert layer.custom_routing_function is None, (
"Custom routing function is only supported for Llama4"
)
activation_type = activation_to_flashinfer_int(layer.activation)
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
input_scale=layer.w13_input_scale,
gemm1_weights=layer.w13_weight,
gemm2_weights=layer.w2_weight,
output1_scales_scalar=layer.output1_scales_scalar,
output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
output2_scales_scalar=layer.output2_scales_scalar,
num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=layer.routing_method_type,
activation_type=activation_type,
)
def make_fp8_moe_alpha_scales_for_fi(
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
g1_alphas = (w13_scale * w13_input_scale).squeeze()
g2_alphas = (w2_scale * w2_input_scale).squeeze()
return g1_alphas, g2_alphas
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi(
min_alignment,
)
layer.intermediate_size_per_partition = new_intermediate
layer.moe_config.intermediate_size_per_partition = new_intermediate
# FI kernels require W31 layout rather than W13.
if layer.moe_config.is_act_and_mul:
@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi(
w13_scale = swap_w13_to_w31(w13_scale)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
# and registration of alpha scales. Note that we do not register
# as nn.Parameters since they are not needed for weight-reloading.
# and registration of alpha scales.
if is_trtllm and not block_quant:
assert w13_input_scale is not None
assert w2_input_scale is not None
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
# Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
# Some FP8 models have near-zero block scales (~1e-23) for dead/unused

View File

@@ -53,7 +53,10 @@ logger = init_logger(__name__)
def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
if isinstance(x, torch.Tensor):
x = x.dtype
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
try:
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
except AttributeError:
# Some torch builds do not define float8_e4m3fnuz.
return False
# We need to pass in the is_hopper flag as argument because the function

View File

@@ -0,0 +1,373 @@
import torch
import numpy as np
from gguf.constants import GGMLQuantizationType
def get_awq_format(w, group_size=128, w_bit=4):
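# Fake-quantize `w` to 4-bit group-wise asymmetric values and repack the result
# into AutoAWQ-style GEMM tensors: int32 qweight / qzeros plus per-group scales.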
org_w_shape = w.shape
ori_w_dtype = torch.get_default_dtype()
assert w_bit == 4
assert w.shape[1] % group_size == 0
in_features = org_w_shape[1]
w = w.reshape(-1, group_size)
assert torch.isnan(w).sum() == 0
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2**w_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
w = (
torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
) * scales
zeros = zeros.view(org_w_shape[0], -1)
scales = scales.view(org_w_shape[0], -1)
w = w.reshape(org_w_shape)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0
scales = scales.t().contiguous()  # (in_features // group_size, out_features)
zeros = zeros.t().contiguous()  # (in_features // group_size, out_features)
# from auto awq
scale_zeros = zeros * scales
scales = scales.clone().to(ori_w_dtype)
pack_num = 32 // w_bit
intweight = []
for idx in range(in_features):
intweight.append(
torch.round(
(w[:, idx] + scale_zeros[idx // group_size])
/ scales[idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)
qweight = torch.zeros(
(intweight.shape[0], intweight.shape[1] // 32 * w_bit),
dtype=torch.int32,
device=intweight.device,
)
for col in range(intweight.shape[1] // pack_num):
order_map = [0, 2, 4, 6, 1, 3, 5, 7]  # AWQ interleaved packing order (w_bit is asserted to be 4)
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * w_bit)
zeros = zeros.to(dtype=torch.int32, device=qweight.device)
qzeros = torch.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * w_bit),
dtype=torch.int32,
device=zeros.device,
)
for col in range(zeros.shape[1] // pack_num):
order_map = [0, 2, 4, 6, 1, 3, 5, 7]  # AWQ interleaved packing order (w_bit is asserted to be 4)
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * w_bit)
return qweight, qzeros, scales
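# A minimal usage sketch of the packer above; the sizes are arbitrary and the
# output shapes are inferred from the packing code.
if __name__ == "__main__":
    w = torch.randn(64, 256, dtype=torch.float16)  # (out_features, in_features)
    qweight, qzeros, scales = get_awq_format(w, group_size=128, w_bit=4)
    # qweight: (in_features, out_features // 8) int32, eight 4-bit values per int32
    # qzeros:  (in_features // group_size, out_features // 8) int32
    # scales:  (in_features // group_size, out_features)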
GGML_BLOCK_SIZES = {
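# Bytes per raw GGML block; Q4_0/Q5_0/Q8_0 blocks encode 32 weights, the K-quants and IQ4_XS encode 256.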
"F32": 4,
"F16": 2,
"Q4_0": 2 + 16,
"Q5_0": 2 + 4 + 16,
"Q8_0": 2 + 32,
"Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
"Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
"Q4_K": 2 + 2 + 12 + 256 // 2,
"Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
"Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
"IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
}
def dequantize_f32(data):
return np.frombuffer(data, dtype=np.float32)
def dequantize_f16(data):
return np.frombuffer(data, dtype=np.float16)
def dequantize_q4_0(data):
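# Q4_0 block layout: one float16 scale followed by 16 bytes packing 32 unsigned 4-bit quants; value = scale * (q - 8).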
num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)
qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]
return np.concatenate([
scales * ((qs & 0xf).astype(np.int8) - 8),
scales * ((qs >> 4).astype(np.int8) - 8),
], axis=1)
def dequantize_q5_0(data):
num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]
bits = np.unpackbits(qh, axis=-1, bitorder="little")
x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16
return np.concatenate([
scales * x0,
scales * x1,
], axis=1)
def dequantize_q8_0(data):
num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)
qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
return scales * qs
def dequantize_q2_k(data):
block_size = GGML_BLOCK_SIZES["Q2_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
qs = data_u8[:, 16:80].reshape(num_blocks, 64)
tmp = np.stack([
qs[:, 00:16] >> 0,
qs[:, 16:32] >> 0,
qs[:, 00:16] >> 2,
qs[:, 16:32] >> 2,
qs[:, 00:16] >> 4,
qs[:, 16:32] >> 4,
qs[:, 00:16] >> 6,
qs[:, 16:32] >> 6,
qs[:, 32:48] >> 0,
qs[:, 48:64] >> 0,
qs[:, 32:48] >> 2,
qs[:, 48:64] >> 2,
qs[:, 32:48] >> 4,
qs[:, 48:64] >> 4,
qs[:, 32:48] >> 6,
qs[:, 48:64] >> 6,
], axis=1)
return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
def dequantize_q3_k(data):
block_size = GGML_BLOCK_SIZES["Q3_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
bits = 4 ^ (bits << 2)
qs = data_u8[:, 32:32 + 64].astype(np.int16)
a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
scales[:, 0] = (a & 15) | ((c & 3) << 4)
scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)
return d * (scales - 32) * np.stack([
(((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
(((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
(((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
(((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
(((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
(((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
(((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
(((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
(((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
(((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
(((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
(((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
(((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
(((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
(((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
(((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
], axis=1)
def dequantize_q4_k(data, device=None):
block_size = GGML_BLOCK_SIZES["Q4_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
# Casting to float32 because float16 is very slow on CPU
scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
# Dequantize scales and offsets (6 bits and 4 + 2 bits)
factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)
# Interleave low and high quantized bits
qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
# Dequantize final weights using scales and offsets
weight = factors * qs2 - offsets
if device is None:
return weight
return torch.from_numpy(weight).to(device=device)
def dequantize_q5_k(data):
block_size = GGML_BLOCK_SIZES["Q5_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1)
qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32)
bits = np.unpackbits(qh, axis=-1, bitorder="little")
qs_hi_4 = qs >> 4
qs_lo_4 = qs & 15
scales_lo_6 = scales[:, :8] & 63
scales_hi_6 = scales[:, :8] >> 6
scales_lo_4 = scales[:, 8:] & 15
scales_hi_4 = scales[:, 8:] >> 4
m1 = dmin * scales_lo_6[:, 4]
m2 = dmin * scales_lo_6[:, 5]
m3 = dmin * scales_lo_6[:, 6]
m4 = dmin * scales_lo_6[:, 7]
m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))
d1 = d * scales_lo_6[:, 0]
d2 = d * scales_lo_6[:, 1]
d3 = d * scales_lo_6[:, 2]
d4 = d * scales_lo_6[:, 3]
d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))
return np.concatenate([
d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
], axis=1)
def dequantize_q6_k(data, device = None):
block_size = GGML_BLOCK_SIZES["Q6_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)
scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
# TODO use uint8 and cast later?
ql = data_u8[:, :128].astype(np.int16)
qh = data_u8[:, 128:192].astype(np.int16)
sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)
# Unpack bits, subtraction requires signed data type
q1 = (ql[:, :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32
q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32
q3 = (ql[:, :32 ] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32
q4 = (ql[:, 32:64 ] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32
q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32
q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32
q7 = (ql[:, 64:96 ] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32
q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32
# Dequantize
weight = scales * np.concatenate([
sc[:, 0] * q1[:, :16],
sc[:, 1] * q1[:, 16:],
sc[:, 2] * q2[:, :16],
sc[:, 3] * q2[:, 16:],
sc[:, 4] * q3[:, :16],
sc[:, 5] * q3[:, 16:],
sc[:, 6] * q4[:, :16],
sc[:, 7] * q4[:, 16:],
sc[:, 8] * q5[:, :16],
sc[:, 9] * q5[:, 16:],
sc[:, 10] * q6[:, :16],
sc[:, 11] * q6[:, 16:],
sc[:, 12] * q7[:, :16],
sc[:, 13] * q7[:, 16:],
sc[:, 14] * q8[:, :16],
sc[:, 15] * q8[:, 16:],
], axis=1)
if device is None:
return weight
return torch.from_numpy(weight).to(device=device)
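def _q6_k_precedence_check():
    # Illustrative sketch only, not part of the upstream file. The Q6_K unpacking
    # above writes `lo | hi - 32`, which Python parses as `lo | (hi - 32)`; because
    # the operands are signed int16 and the low nibble never overlaps the shifted
    # high bits, this equals the intended `(lo | hi) - 32` for every valid value.
    lo = np.arange(16, dtype=np.int16)                    # low 4-bit values (ql & 0xF)
    for hi in np.array([0, 16, 32, 48], dtype=np.int16):  # ((qh >> k) & 3) << 4
        assert np.array_equal(lo | (hi - 32), (lo | hi) - 32)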
QK_K = 256
kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
def dequantize_iq4_xs(data):
block_size = GGML_BLOCK_SIZES["IQ4_XS"]
num_blocks = len(data) // block_size
d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)
scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]
scales_l = data_u8[:, :4].reshape(num_blocks, 4)
qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)
ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)
for ib in range(QK_K // 32):
ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)
dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)
qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf
qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4
y = np.zeros((num_blocks, QK_K), dtype=np.float32)
for ib in range(QK_K // 32):
y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]
y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]
return y.flatten()
GGML_DEQUANTIZE = {
int(GGMLQuantizationType.F32): dequantize_f32,
int(GGMLQuantizationType.F16): dequantize_f16,
int(GGMLQuantizationType.Q4_0): dequantize_q4_0,
int(GGMLQuantizationType.Q5_0): dequantize_q5_0,
int(GGMLQuantizationType.Q8_0): dequantize_q8_0,
int(GGMLQuantizationType.Q2_K): dequantize_q2_k,
int(GGMLQuantizationType.Q3_K): dequantize_q3_k,
int(GGMLQuantizationType.Q4_K): dequantize_q4_k,
int(GGMLQuantizationType.Q5_K): dequantize_q5_k,
int(GGMLQuantizationType.Q6_K): dequantize_q6_k,
int(GGMLQuantizationType.IQ4_XS): dequantize_iq4_xs,
}
def dequant_gguf(data, type, shape):
values = GGML_DEQUANTIZE[type](data)
values = torch.from_numpy(values).view(shape)
return values
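def _dequant_gguf_example(raw: bytes) -> torch.Tensor:
    # Usage sketch only, not part of the upstream file; the (4096, 4096) shape is
    # a hypothetical example. `raw` is the packed Q6_K payload of such a tensor,
    # i.e. 4096 * 4096 // 256 blocks of 210 bytes each, as read from a GGUF file.
    return dequant_gguf(raw, int(GGMLQuantizationType.Q6_K), (4096, 4096))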

View File

@@ -255,18 +255,6 @@ def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tenso
return w2_packed.size(1) * marlin_tile_size
def marlin_make_workspace(
output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
max_workspace_size = (
output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
) * GPTQ_MARLIN_MAX_PARALLEL
return torch.zeros(
max_workspace_size, dtype=torch.int, device=device, requires_grad=False
)
def marlin_make_workspace_new(
device: torch.device, max_blocks_per_sm: int = 1
) -> torch.Tensor:
@@ -297,12 +285,6 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
)
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(
torch.empty(0, dtype=torch.int, device=device), requires_grad=False
)
def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices

View File

@@ -175,7 +175,7 @@ try:
op_func=_dequant_mxfp4,
fake_impl=_dequant_mxfp4_fake,
)
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
dequant_mxfp4 = None
except AttributeError as error:
raise error
@@ -185,6 +185,6 @@ try:
op_func=_quant_dequant_mxfp4,
fake_impl=_quant_dequant_mxfp4_fake,
)
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
quant_dequant_mxfp4 = None
except AttributeError as error:
raise error

View File

@@ -271,12 +271,12 @@ def scaled_quantize(
If None, uses input dtype. Use torch.float32 for higher precision.
"""
group_shape = _normalize_quant_group_shape(x, group_shape)
assert quant_dtype.is_floating_point, (
"currently `scaled_quantize` only supports floating point dtypes "
"but could be extended to support other dtypes"
)
# assert quant_dtype.is_floating_point, (
# "currently `scaled_quantize` only supports floating point dtypes "
# "but could be extended to support other dtypes"
# )
finfo = torch.finfo(quant_dtype)
finfo = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype)
# Convert to compute dtype if specified
x_compute = x if compute_dtype is None else x.to(compute_dtype)

View File

@@ -0,0 +1,114 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs
class W8a16Config(QuantizationConfig):
"""Config class for W8a16.
"""
def __init__(
self,
) -> None:
pass
def __repr__(self) -> str:
return ("W8a16Config")
def get_name(self) -> str:
return "w8a16"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
def get_min_capability(self) -> int:
return 75
@staticmethod
def get_config_filenames():
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "W8a16Config":
return cls()
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["W8a16LinearMethod"]:
if isinstance(layer, LinearBase):
return W8a16LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class W8a16LinearMethod(LinearMethodBase):
"""Linear method for w8a16.
"""
def __init__(self, quant_config: W8a16Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.int8,
),
requires_grad=False,
)
set_weight_attrs(
weight, {
"input_dim": 1,
"output_dim": 0,
})
scales = Parameter(
torch.empty(
1,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": None,
"output_dim": 1,
})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.weight
scales = layer.scales
out_shape = (x.shape[:-1] + (qweight.shape[-2],))
reshaped_x = x.reshape(-1, x.shape[-1])
out = ops.linear_w8a16(reshaped_x, qweight, scales, format="TN")
if bias is not None:
out = out + bias
return out.reshape(out_shape)
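def _linear_w8a16_reference(x: torch.Tensor, qweight: torch.Tensor,
                            scales: torch.Tensor) -> torch.Tensor:
    # Reference sketch only, not the corex kernel; it assumes ops.linear_w8a16
    # with format="TN" is equivalent to dequantizing the int8 weight with its
    # per-output-channel scale and applying a plain fp16/bf16 matmul.
    # qweight: (out_features, in_features) int8; scales: (1, out_features).
    w = qweight.to(x.dtype) * scales.reshape(-1, 1).to(x.dtype)
    return torch.nn.functional.linear(x, w)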

Some files were not shown because too many files have changed in this diff.