Upgrade to vllm 0.17.0 corex v4.1 overlay

@@ -8,6 +8,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm import envs
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
@@ -130,13 +131,12 @@ class SiluAndMul(CustomOp):

    def __init__(self, *, compile_native: bool = True):
        super().__init__(compile_native=compile_native)
        if current_platform.is_cuda_alike():
        if current_platform.is_cuda_alike() or current_platform.is_xpu():
            from vllm import _custom_ops as ops
            self.op = ops.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
            if envs.VLLM_USE_SILU_QUANT_FUSION:
                self.op = ops.silu_and_mul_quant
            else:
                self.op = ops.silu_and_mul
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

@@ -146,11 +146,15 @@ class SiluAndMul(CustomOp):
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
    def forward_cuda(self, x: torch.Tensor, out_dim: int = 0) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        if envs.VLLM_USE_SILU_QUANT_FUSION:
            quant_out, out_scales = self.op(x, out_dim)
            out = (quant_out, out_scales, x.dtype)
        else:
            out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@@ -174,7 +178,6 @@ class MulAndSilu(CustomOp):
    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_xpu():
            # self.op = torch.ops._C.mul_and_silu
            from vllm import _custom_ops as ops
            self.op = ops.mul_and_silu
        elif current_platform.is_cpu():
@@ -397,7 +400,6 @@ class NewGELU(CustomOp):
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        ):
            # self.op = torch.ops._C.gelu_new
            from vllm import _custom_ops as ops
            self.op = ops.gelu_new

@@ -427,7 +429,8 @@ class FastGELU(CustomOp):
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        ):
            self.op = torch.ops._C.gelu_fast
            from vllm import _custom_ops as ops
            self.op = ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
@@ -455,7 +458,6 @@ class QuickGELU(CustomOp):
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        ):
            # self.op = torch.ops._C.gelu_quick
            from vllm import _custom_ops as ops
            self.op = ops.gelu_quick
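A minimal sketch (not part of the diff) of what the overlay changes for SiluAndMul: forward_native stays the PyTorch reference shown above, while forward_cuda now returns a (quant_out, scales, dtype) tuple when VLLM_USE_SILU_QUANT_FUSION is set, so callers must branch on the flag.

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Same math as forward_native: SiLU on the first half of the last dim,
    # multiplied element-wise by the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 256)
out = silu_and_mul_reference(x)  # plain path: a [4, 128] tensor
# With VLLM_USE_SILU_QUANT_FUSION enabled, the overlay's forward_cuda instead
# returns (quant_out, out_scales, x.dtype) from ops.silu_and_mul_quant, so the
# consumer (e.g. the following quantized linear layer) must unpack the tuple.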
@@ -12,7 +12,7 @@ from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.kv_transfer_utils import (
    maybe_transfer_kv_layer,
    maybe_transfer_kv_layer
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
@@ -40,6 +40,9 @@ from vllm.v1.kv_cache_interface import (
    KVCacheSpec,
    SlidingWindowSpec,
)
from .extra_cache import StaticQuantManager
from ixformer.core import config
_USE_TORCH_OPS = config.IXFORMER_USE_TORCH_OPS

if TYPE_CHECKING:
    from vllm.model_executor.layers.attention import MLAAttention
@@ -202,6 +205,7 @@ class Attention(nn.Module, AttentionLayerBase):
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        extra_cache_para: dict | None = None,
        **extra_impl_args,
    ) -> None:
        """
@@ -258,6 +262,7 @@ class Attention(nn.Module, AttentionLayerBase):

        self.num_heads = num_heads
        self.head_size = head_size
        self.hidden_size = head_size * num_heads
        self.head_size_v = self.head_size if head_size_v is None else head_size_v
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
@@ -326,6 +331,15 @@ class Attention(nn.Module, AttentionLayerBase):
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        if extra_cache_para is not None:
            self.quant_manager = StaticQuantManager(
                layer_id=extra_cache_para.get("layer_id", None),
                shape=(self.num_kv_heads, self.head_size_v),
                dtype=torch.float32,
                total_layer_num=extra_cache_para.get("total_layer_num", None),
            )
        else:
            self.quant_manager = None
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

@@ -333,7 +347,10 @@ class Attention(nn.Module, AttentionLayerBase):
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()
        if _USE_TORCH_OPS:
            self.use_direct_call = False
        else:
            self.use_direct_call = True

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
@@ -349,14 +366,26 @@ class Attention(nn.Module, AttentionLayerBase):
            compilation_config.static_forward_context,
        )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]
        self.is_i8qi8ki8v = envs.VLLM_ATTN_OPT_LEVEL == 1
        self.is_i8qi8kf16v = envs.VLLM_ATTN_OPT_LEVEL == 2
        if self.is_i8qi8kf16v:
            self.kv_cache_scale = [
                torch.tensor([])
                for _ in range(get_current_vllm_config().parallel_config.pipeline_parallel_size)
            ]
        elif self.is_i8qi8ki8v:
            self.kv_cache_scale = [
                [torch.tensor([]), torch.tensor([])]
                for _ in range(get_current_vllm_config().parallel_config.pipeline_parallel_size)
            ]

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(self, quant_config, prefix)
@@ -396,6 +425,7 @@ class Attention(nn.Module, AttentionLayerBase):
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        optional_args = {}
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        output_dtype = query.dtype
@@ -412,15 +442,8 @@ class Attention(nn.Module, AttentionLayerBase):
            query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            if output_shape is None:
                # Handle both 2D [num_tokens, hidden] and
                # 3D [num_tokens, heads, head_dim] query
                num_tokens = query.shape[0]
                output_shape = torch.Size(
                    (num_tokens, self.num_heads * self.head_size_v)
                )
            output_shape = output_shape if output_shape is not None else query.shape
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
@@ -430,46 +453,50 @@ class Attention(nn.Module, AttentionLayerBase):
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size_v)
            kv_cache_dummy_dep = None
            if self.use_direct_call:
                # Skip this if sharing KV cache with an earlier attention layer.
                if (
                    not self.attn_backend.forward_includes_kv_cache_update
                    and self.kv_sharing_target_layer_name is None
                    and key is not None
                    and value is not None
                ):
                    kv_cache_dummy_dep = unified_kv_cache_update(
                        key, value, self.layer_name
                def direct_forward(layer_name: str, output: torch.Tensor):
                    forward_context: ForwardContext = get_forward_context()
                    attn_metadata = forward_context.attn_metadata
                    if isinstance(attn_metadata, dict):
                        attn_metadata = attn_metadata[layer_name]
                    self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                    # Skip this if sharing KV cache with an earlier attention layer.
                    if self.is_i8qi8ki8v or self.is_i8qi8kf16v:
                        optional_args["kv_cache_scale"] = self.kv_cache_scale[
                            forward_context.virtual_engine
                        ]
                    output = self.impl.forward(
                        self,
                        query,
                        key,
                        value,
                        self_kv_cache,
                        attn_metadata,
                        output=output,
                        **optional_args,
                    )
                unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
                return output
                return maybe_transfer_kv_layer(direct_forward)(self.layer_name, output)
            else:
                # Skip this if sharing KV cache with an earlier attention layer.
                if (
                    not self.attn_backend.forward_includes_kv_cache_update
                    and self.kv_sharing_target_layer_name is None
                    and key is not None
                    and value is not None
                ):
                    kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                if self.is_i8qi8ki8v:
                    forward_context: ForwardContext = get_forward_context()
                    kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine][0]
                    v_cache_scale = self.kv_cache_scale[forward_context.virtual_engine][1]
                elif self.is_i8qi8kf16v:
                    forward_context: ForwardContext = get_forward_context()
                    kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine]
                    v_cache_scale = None
                else:
                    kv_cache_scale = None
                    v_cache_scale = None
                torch.ops.vllm.unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                    kv_cache_scale,
                    v_cache_scale
                )
                return output.view(-1, hidden_size)
                return output.view(-1, self.hidden_size)
        else:
            assert self.attn_backend.forward_includes_kv_cache_update, (
                "Split KV cache update not supported when output tensor not provided."
@@ -521,6 +548,7 @@ class Attention(nn.Module, AttentionLayerBase):
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        # TODO: the kernel does not support a KV cache for sliding_window; use FullAttentionSpec instead
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported with sliding window"
@@ -689,6 +717,8 @@ def unified_attention_with_output(
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    kv_cache_scale: torch.Tensor | None = None,
    v_cache_scale: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
    kv_cache_dummy_dep: torch.Tensor | None = None,
@@ -696,9 +726,7 @@ def unified_attention_with_output(
    # kv_cache_dummy_dep is not used but accepting it creates a data dependency
    # that ensures torch.compile preserves ordering between KV cache update and
    # attention forward.
    del kv_cache_dummy_dep
    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)

    self.impl.forward(
        self,
        query,
@@ -707,6 +735,7 @@ def unified_attention_with_output(
        kv_cache,
        attn_metadata,
        output=output,
        kv_cache_scale=[kv_cache_scale, v_cache_scale] if envs.VLLM_ATTN_OPT_LEVEL == 1 else kv_cache_scale,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )
@@ -718,6 +747,8 @@ def unified_attention_with_output_fake(
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    kv_cache_scale: torch.Tensor | None = None,
    v_cache_scale: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
    kv_cache_dummy_dep: torch.Tensor | None = None,
vllm/model_executor/layers/attention/extra_cache.py (new file, 131 lines)
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import torch
from filelock import FileLock

import vllm.envs as envs
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger

logger = init_logger(__name__)


class StaticQuantManager:
    def __init__(
        self,
        layer_id: int,
        shape: tuple,
        dtype: torch.dtype,
        total_layer_num: int,
        device: str | None = None,
        tp_size: int | None = None,
        tp_rank: int | None = None,
        file_save_path: str | None = None,
        save_step: int = 100,
        info_step: int = 100,
    ):
        # update parameters
        if tp_size is None:
            tp_size = get_tensor_model_parallel_world_size()
        if tp_rank is None:
            tp_rank = get_tensor_model_parallel_rank()
        if file_save_path is None:
            file_save_path = envs.VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH
        if device is None:
            device = "cuda"

        # check parameters
        if file_save_path in [None, ""]:
            self.disable = True
            return

        para_dir = os.path.dirname(file_save_path)
        assert os.path.exists(para_dir), (
            f"StaticQuantManager workdir {para_dir} does not exist!"
        )
        self.disable = os.path.exists(file_save_path)
        if self.disable:
            return

        assert layer_id is not None
        assert total_layer_num is not None

        world_rank = torch.distributed.get_rank()
        work_dir = os.path.join(para_dir, "StaticQuantManagerWorkdir")
        self.operator = world_rank == 0 and layer_id == 0
        if not os.path.exists(work_dir):
            if self.operator:
                logger.debug(f"StaticQuantManager creating {work_dir}!")
                os.mkdir(work_dir)
        self.file_save_path = file_save_path
        self.work_dir = work_dir

        self.tp_size = tp_size
        self.tp_rank = tp_rank
        self.world_rank = world_rank
        self.layer_id = layer_id
        self.total_layer_num = total_layer_num
        self.save_step = save_step
        self.info_step = info_step
        self.update_count = 0
        self.save_flag = False
        self.scales = torch.zeros(shape, dtype=dtype, device=device)
        logger.debug(
            f"StaticQuantManager info: world_rank:{self.world_rank} tp_rank:{self.tp_rank} layer_id:{self.layer_id} scale shape:{shape} self.scales:{self.scales.device}"
        )

    def check_enable(self):
        return not self.disable

    def update_data(self, data):
        if self.disable:
            return

        self.scales = torch.max(data, self.scales)

        # save file
        self.update_count += 1
        if self.update_count % self.info_step == 0 and self.operator:
            logger.info(f"StaticQuantManager update_data has run for {self.update_count} steps")

        if self.update_count % self.save_step == 0:
            # step1: save to disk
            save_file_path = os.path.join(
                self.work_dir, f"{self.layer_id}_{self.tp_rank}.pt"
            )
            lock_file_path = os.path.join(
                self.work_dir, f"{self.layer_id}_{self.tp_rank}.lock"
            )
            lock = FileLock(lock_file_path)
            cpu_data = self.scales.cpu()
            with lock:
                torch.save(cpu_data, save_file_path)

            # step2: merge and save
            if self.save_flag and self.operator:
                save_dict = {}
                for idx in range(self.total_layer_num):
                    tp_datas = []
                    for tp_rank in range(self.tp_size):
                        load_file = os.path.join(self.work_dir, f"{idx}_{tp_rank}.pt")
                        lock_file_path = os.path.join(
                            self.work_dir, f"{idx}_{tp_rank}.lock"
                        )
                        lock = FileLock(lock_file_path)
                        with lock:
                            cur_data = torch.load(load_file)
                        tp_datas.append(cur_data)

                    layer_data = torch.concat(tp_datas)
                    save_dict[f"layer_{idx}"] = layer_data

                torch.save(save_dict, self.file_save_path)
                logger.info(
                    f"StaticQuantManager saved to {self.file_save_path} after {self.update_count} steps"
                )
            self.save_flag = True
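A hedged usage sketch of the new StaticQuantManager (the calling layer, tensor names, and shapes below are assumptions, not values taken from the diff): during calibration a layer feeds per-head abs-max statistics into update_data, and rank 0 of layer 0 periodically merges the per-TP files into the final scale file.

import torch

# Assumed shapes: num_kv_heads=8, head_size_v=128, 32 layers in total.
mgr = StaticQuantManager(
    layer_id=0,
    shape=(8, 128),
    dtype=torch.float32,
    total_layer_num=32,
)
if mgr.check_enable():
    # value: [num_tokens, num_kv_heads, head_size_v]; keep a running
    # element-wise max of |V| per head, as update_data does internally.
    value = torch.randn(1024, 8, 128, device="cuda")
    per_head_absmax = value.abs().amax(dim=0).to(torch.float32)
    mgr.update_data(per_head_absmax)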
File diff suppressed because it is too large
@@ -2,21 +2,94 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.ops.vit_attn_wrappers import (
|
||||
vit_flash_attn_wrapper,
|
||||
vit_flashinfer_wrapper,
|
||||
vit_torch_sdpa_wrapper,
|
||||
vit_triton_attn_wrapper,
|
||||
)
|
||||
import ixformer.contrib.vllm_flash_attn as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Batch buckets for cuDNN graph caching.
|
||||
# Graphs use batch size and max sequence length as cache key.
|
||||
# This avoids creating a new graph for each unique set of
|
||||
# batch size and max sequence length at runtime.
|
||||
# From the cuDNN team's performance measurements, there
|
||||
# is no significant kernel performance difference between padding
|
||||
# to a smaller batch size/seq length and padding to larger
|
||||
# ones. The bucketing here is solely used to avoid memory
|
||||
# operation overhead, which won't be needed if we have CUDA
|
||||
# graph support in the future.
|
||||
# TODO: Remove buckets after issue #34763
|
||||
# (cuda graph support) is addressed.
|
||||
FLASHINFER_BATCH_BUCKETS = [8, 16, 32, 64]
|
||||
FLASHINFER_MAX_SEQLEN_BUCKETS = [
|
||||
1 * 1024,
|
||||
2 * 1024,
|
||||
4 * 1024,
|
||||
8 * 1024,
|
||||
16 * 1024,
|
||||
32 * 1024,
|
||||
64 * 1024,
|
||||
128 * 1024,
|
||||
]
|
||||
|
||||
# Workspace buffer for FlashInfer CuDNN backend
|
||||
FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024
|
||||
_flashinfer_workspace_buffer: torch.Tensor | None = None
|
||||
|
||||
|
||||
def _get_flashinfer_workspace_buffer() -> torch.Tensor:
|
||||
global _flashinfer_workspace_buffer
|
||||
if _flashinfer_workspace_buffer is None:
|
||||
_flashinfer_workspace_buffer = torch.zeros(
|
||||
FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES,
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
return _flashinfer_workspace_buffer
|
||||
|
||||
|
||||
def add_padding_to_seqlens(
|
||||
seq: np.ndarray,
|
||||
batch_size: int,
|
||||
padding_value: int,
|
||||
) -> np.ndarray:
|
||||
batch_size_padded = next(
|
||||
(b for b in FLASHINFER_BATCH_BUCKETS if b >= batch_size),
|
||||
round_up(batch_size, FLASHINFER_BATCH_BUCKETS[0]),
|
||||
)
|
||||
if batch_size_padded == batch_size:
|
||||
return seq
|
||||
return np.concatenate(
|
||||
[
|
||||
seq,
|
||||
np.full((batch_size_padded - batch_size,), padding_value, dtype=seq.dtype),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def bucket_flashinfer_max_seqlen(
|
||||
real_max_seqlen: int,
|
||||
) -> int:
|
||||
if real_max_seqlen <= 0:
|
||||
return FLASHINFER_MAX_SEQLEN_BUCKETS[0]
|
||||
return next(
|
||||
(s for s in FLASHINFER_MAX_SEQLEN_BUCKETS if s >= real_max_seqlen),
|
||||
round_up(real_max_seqlen, FLASHINFER_MAX_SEQLEN_BUCKETS[-1]),
|
||||
)
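A quick illustration (a sketch with illustrative numbers, not part of the diff) of how the two helpers above bucket batch size and max sequence length for the cuDNN graph cache:

import numpy as np

seqlens = np.array([100, 37, 512], dtype=np.int32)  # 3 sequences
padded = add_padding_to_seqlens(seqlens, batch_size=3, padding_value=0)
assert len(padded) == 8          # batch 3 rounds up to the first bucket (8)

assert bucket_flashinfer_max_seqlen(900) == 1 * 1024        # next bucket up
assert bucket_flashinfer_max_seqlen(200_000) == 256 * 1024  # past the last bucket: round_up to a 128K multiple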
|
||||
|
||||
|
||||
# --8<-- [start:mm_encoder_attn]
|
||||
@CustomOp.register("mm_encoder_attn")
|
||||
@@ -24,6 +97,67 @@ class MMEncoderAttention(CustomOp):
|
||||
"""Multi-headed attention without any cache, used for multimodal encoder."""
|
||||
|
||||
# --8<-- [end:mm_encoder_attn]
|
||||
@classmethod
|
||||
def compute_max_seqlen(
|
||||
cls,
|
||||
attn_backend: AttentionBackendEnum,
|
||||
cu_seqlens: np.ndarray,
|
||||
) -> int:
|
||||
max_seqlen = 0
|
||||
if (
|
||||
attn_backend
|
||||
in (
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
)
|
||||
and len(cu_seqlens) >= 2
|
||||
):
|
||||
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
|
||||
if attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||
max_seqlen = bucket_flashinfer_max_seqlen(max_seqlen)
|
||||
return max_seqlen
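    # Worked example (sketch): with cu_seqlens = [0, 100, 612] the raw max segment
    # length is 512; for FLASH_ATTN/ROCM_AITER_FA/TRITON_ATTN that value is returned
    # as-is, while for FLASHINFER it is rounded up to the 1K bucket by
    # bucket_flashinfer_max_seqlen.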
|
||||
|
||||
@classmethod
|
||||
def maybe_compute_sequence_lengths(
|
||||
cls,
|
||||
attn_backend: AttentionBackendEnum,
|
||||
cu_seqlens: np.ndarray,
|
||||
) -> np.ndarray | None:
|
||||
if attn_backend != AttentionBackendEnum.FLASHINFER:
|
||||
return None
|
||||
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
sequence_lengths = add_padding_to_seqlens(
|
||||
sequence_lengths, len(sequence_lengths), 0
|
||||
)
|
||||
return sequence_lengths
|
||||
|
||||
@classmethod
|
||||
def maybe_recompute_cu_seqlens(
|
||||
cls,
|
||||
attn_backend: AttentionBackendEnum,
|
||||
cu_seqlens: np.ndarray,
|
||||
hidden_size: int,
|
||||
tp_size: int,
|
||||
) -> np.ndarray:
|
||||
if attn_backend != AttentionBackendEnum.FLASHINFER:
|
||||
return cu_seqlens
|
||||
|
||||
batch_size = len(cu_seqlens) - 1
|
||||
scale = hidden_size // tp_size
|
||||
cu_seqlens = cu_seqlens * scale
|
||||
|
||||
cu_seqlens_qko = cu_seqlens
|
||||
cu_seqlens_v = cu_seqlens * 3
|
||||
|
||||
cu_seqlens_qko = add_padding_to_seqlens(
|
||||
cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
|
||||
)
|
||||
cu_seqlens_v = add_padding_to_seqlens(
|
||||
cu_seqlens_v, batch_size, cu_seqlens_v[-1]
|
||||
)
|
||||
return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
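    # Worked example (sketch): cu_seqlens = [0, 4, 10], hidden_size = 1280, tp_size = 2
    # -> scale = 640, cu_seqlens_qko = [0, 2560, 6400], cu_seqlens_v = [0, 7680, 19200];
    # batch 2 is padded up to the 8-bucket by repeating the last offset, and the two
    # padded arrays are concatenated into a single array for the FlashInfer wrapper.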
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -46,10 +180,9 @@ class MMEncoderAttention(CustomOp):
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = scale
|
||||
self.scale = 1.0 / (head_size**0.5) if scale is None else scale
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.layer_name = prefix
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0, (
|
||||
f"num_heads ({self.num_heads}) is not "
|
||||
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
||||
@@ -72,9 +205,14 @@ class MMEncoderAttention(CustomOp):
|
||||
}
|
||||
|
||||
self._fa_version = (
|
||||
get_flash_attn_version() if self.is_flash_attn_backend else None
|
||||
get_flash_attn_version(head_size=head_size)
|
||||
if self.is_flash_attn_backend
|
||||
else None
|
||||
)
|
||||
|
||||
if self.attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||
_get_flashinfer_workspace_buffer()
|
||||
|
||||
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
|
||||
|
||||
@classmethod
|
||||
@@ -148,23 +286,27 @@ class MMEncoderAttention(CustomOp):
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4
        query = query.view(bsz * q_len, self.num_heads, self.head_size)
        key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)

        query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
            fa_version=self._fa_version,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        cu_q = torch.tensor(
            [0] + [q_len] * bsz, device=query.device, dtype=torch.int32
        ).cumsum(dim=0, dtype=torch.int32)
        cu_kv = torch.tensor(
            [0] + [kv_len] * bsz, device=query.device, dtype=torch.int32
        ).cumsum(dim=0, dtype=torch.int32)
        out = ops.flash_attn_varlen_func(
            query,
            key,
            value,
            cu_q,
            cu_kv,
            q_len,
            kv_len,
            softmax_scale=self.scale,
            causal=False,
        )
        out = out.view(bsz, q_len, self.num_heads, self.head_size)
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output
            out = out.reshape(bsz, q_len, -1)
        return out
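    # Note (sketch, not part of the diff): because every sequence in the batch shares
    # the same q_len/kv_len here, the cu_q/cu_kv cumsums above are equivalent to
    #   torch.arange(bsz + 1, dtype=torch.int32, device=query.device) * q_len
    # (and the same with kv_len for cu_kv).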
|
||||
|
||||
def _forward_triton(
|
||||
self,
|
||||
@@ -201,6 +343,27 @@ class MMEncoderAttention(CustomOp):
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
return output
|
||||
|
||||
def _forward_flashinfer(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
sequence_lengths: torch.Tensor
|
||||
| None = None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
return vit_flashinfer_wrapper(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
scale=self.scale,
|
||||
workspace_buffer=_get_flashinfer_workspace_buffer(),
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
sequence_lengths=sequence_lengths,
|
||||
)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -208,6 +371,8 @@ class MMEncoderAttention(CustomOp):
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
sequence_lengths: torch.Tensor
|
||||
| None = None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
|
||||
@@ -218,11 +383,17 @@ class MMEncoderAttention(CustomOp):
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
sequence_lengths: torch.Tensor
|
||||
| None = None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
if self.is_flash_attn_backend:
|
||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
|
||||
elif self.attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||
return self._forward_flashinfer(
|
||||
query, key, value, cu_seqlens, max_seqlen, sequence_lengths
|
||||
)
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
else:
|
||||
@@ -238,6 +409,8 @@ class MMEncoderAttention(CustomOp):
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
sequence_lengths: torch.Tensor
|
||||
| None = None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
return self._forward_sdpa(query, key, value, cu_seqlens)
|
||||
|
||||
@@ -248,6 +421,8 @@ class MMEncoderAttention(CustomOp):
|
||||
value: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
|
||||
sequence_lengths: torch.Tensor
|
||||
| None = None, # Only used for FlashInfer CuDNN backend
|
||||
) -> torch.Tensor:
|
||||
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
|
||||
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
|
||||
|
||||
@@ -7,11 +7,17 @@
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
from .chunk import chunk_gated_delta_rule
|
||||
from .fused_recurrent import fused_recurrent_gated_delta_rule
|
||||
from .fused_recurrent import (
|
||||
fused_recurrent_gated_delta_rule,
|
||||
fused_recurrent_gated_delta_rule_packed_decode,
|
||||
)
|
||||
from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update
|
||||
from .layernorm_guard import RMSNormGated
|
||||
|
||||
__all__ = [
|
||||
"RMSNormGated",
|
||||
"chunk_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule_packed_decode",
|
||||
"fused_sigmoid_gating_delta_rule_update",
|
||||
]
|
||||
|
||||
@@ -30,7 +30,7 @@ def chunk_gated_delta_rule_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
@@ -84,7 +84,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
@@ -117,7 +117,7 @@ def chunk_gated_delta_rule(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
r"""
|
||||
@@ -141,7 +141,7 @@ def chunk_gated_delta_rule(
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
Returns:
|
||||
@@ -171,7 +171,7 @@ def chunk_gated_delta_rule(
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
|
||||
>>> o_var, ht_var = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
|
||||
@@ -288,7 +288,7 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# This kernel is slightly different from fla to support Q/K with different head numbers.
|
||||
# In fla, Q/K always have the same head number, so Hg is always equal to H.
|
||||
|
||||
@@ -89,7 +89,7 @@ def chunk_fwd_kernel_o(
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(
|
||||
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
|
||||
@@ -145,7 +145,7 @@ def chunk_fwd_o(
|
||||
h: torch.Tensor,
|
||||
g: torch.Tensor | None = None, # cumsum of log decay
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
|
||||
@@ -102,7 +102,7 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
g: torch.Tensor | None = None,
|
||||
beta: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
@@ -116,7 +116,7 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
g (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
|
||||
@@ -106,12 +106,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
|
||||
else:
|
||||
i_t = 0
|
||||
# Load state index and check for PAD_SLOT_ID (-1)
|
||||
# Load state index and check for invalid entries
|
||||
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
|
||||
tl.int64
|
||||
)
|
||||
# Skip if state index is invalid (PAD_SLOT_ID = -1)
|
||||
if state_idx < 0:
|
||||
# Skip if state index is invalid (NULL_BLOCK_ID=0)
|
||||
if state_idx <= 0:
|
||||
return
|
||||
p_h0 = h0 + state_idx * stride_init_state_token
|
||||
else:
|
||||
@@ -150,12 +150,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
|
||||
# keep the states for multi-query tokens
|
||||
if INPLACE_FINAL_STATE:
|
||||
# Load state index and check for PAD_SLOT_ID (-1)
|
||||
# Load state index and check for invalid entries
|
||||
final_state_idx = tl.load(
|
||||
ssm_state_indices + i_n * stride_indices_seq + i_t
|
||||
).to(tl.int64)
|
||||
# Only store if state index is valid (not PAD_SLOT_ID)
|
||||
if final_state_idx >= 0:
|
||||
# Only store if state index is valid (not NULL_BLOCK_ID=0)
|
||||
if final_state_idx > 0:
|
||||
p_ht = ht + final_state_idx * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
@@ -184,7 +184,7 @@ def fused_recurrent_gated_delta_rule_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -252,6 +252,232 @@ def fused_recurrent_gated_delta_rule_fwd(
|
||||
return o, final_state
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_gated_delta_rule_packed_decode_kernel(
|
||||
mixed_qkv,
|
||||
a,
|
||||
b,
|
||||
A_log,
|
||||
dt_bias,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
ssm_state_indices,
|
||||
scale,
|
||||
stride_mixed_qkv_tok: tl.constexpr,
|
||||
stride_a_tok: tl.constexpr,
|
||||
stride_b_tok: tl.constexpr,
|
||||
stride_init_state_token: tl.constexpr,
|
||||
stride_final_state_token: tl.constexpr,
|
||||
stride_indices_seq: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
SOFTPLUS_THRESHOLD: tl.constexpr,
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
):
|
||||
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
|
||||
o_k = tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_v[:, None] & mask_k[None, :]
|
||||
|
||||
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
|
||||
p_o = o + (i_n * HV + i_hv) * V + o_v
|
||||
|
||||
# Skip if state index is invalid (NULL_BLOCK_ID=0)
|
||||
if state_idx <= 0:
|
||||
zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
|
||||
tl.store(p_o, zero, mask=mask_v)
|
||||
return
|
||||
|
||||
p_h0 = h0 + state_idx * stride_init_state_token
|
||||
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
|
||||
q_off = i_h * K + o_k
|
||||
k_off = (H * K) + i_h * K + o_k
|
||||
v_off = (2 * H * K) + i_hv * V + o_v
|
||||
b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
|
||||
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
|
||||
b_q = b_q * scale
|
||||
|
||||
a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
|
||||
b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)
|
||||
A_log_val = tl.load(A_log + i_hv).to(tl.float32)
|
||||
dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)
|
||||
x = a_val + dt_bias_val
|
||||
softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
|
||||
g_val = -tl.exp(A_log_val) * softplus_x
|
||||
beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32)
|
||||
|
||||
b_h *= exp(g_val)
|
||||
b_v -= tl.sum(b_h * b_k[None, :], 1)
|
||||
b_v *= beta_val
|
||||
b_h += b_v[:, None] * b_k[None, :]
|
||||
b_o = tl.sum(b_h * b_q[None, :], 1)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
p_ht = ht + state_idx * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_packed_decode(
|
||||
mixed_qkv: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
A_log: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
ssm_state_indices: torch.Tensor,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if mixed_qkv.ndim != 2:
|
||||
raise ValueError(
|
||||
f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})."
|
||||
)
|
||||
if mixed_qkv.stride(-1) != 1:
|
||||
raise ValueError("`mixed_qkv` must be contiguous in the last dim.")
|
||||
if a.ndim != 2 or b.ndim != 2:
|
||||
raise ValueError(
|
||||
f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})."
|
||||
)
|
||||
if a.stride(-1) != 1 or b.stride(-1) != 1:
|
||||
raise ValueError("`a`/`b` must be contiguous in the last dim.")
|
||||
if A_log.ndim != 1 or dt_bias.ndim != 1:
|
||||
raise ValueError("`A_log`/`dt_bias` must be 1D tensors.")
|
||||
if A_log.stride(0) != 1 or dt_bias.stride(0) != 1:
|
||||
raise ValueError("`A_log`/`dt_bias` must be contiguous.")
|
||||
if ssm_state_indices.ndim != 1:
|
||||
raise ValueError(
|
||||
f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})."
|
||||
)
|
||||
if not out.is_contiguous():
|
||||
raise ValueError("`out` must be contiguous.")
|
||||
|
||||
dev = mixed_qkv.device
|
||||
if (
|
||||
a.device != dev
|
||||
or b.device != dev
|
||||
or A_log.device != dev
|
||||
or dt_bias.device != dev
|
||||
or initial_state.device != dev
|
||||
or out.device != dev
|
||||
or ssm_state_indices.device != dev
|
||||
):
|
||||
raise ValueError("All inputs must be on the same device.")
|
||||
|
||||
B = mixed_qkv.shape[0]
|
||||
if a.shape[0] != B or b.shape[0] != B:
|
||||
raise ValueError(
|
||||
"Mismatched batch sizes: "
|
||||
f"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}."
|
||||
)
|
||||
if ssm_state_indices.shape[0] != B:
|
||||
raise ValueError(
|
||||
f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))."
|
||||
)
|
||||
|
||||
if initial_state.ndim != 4:
|
||||
raise ValueError(
|
||||
f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})."
|
||||
)
|
||||
if initial_state.stride(-1) != 1:
|
||||
raise ValueError("`initial_state` must be contiguous in the last dim.")
|
||||
HV, V, K = initial_state.shape[-3:]
|
||||
if a.shape[1] != HV or b.shape[1] != HV:
|
||||
raise ValueError(
|
||||
f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})."
|
||||
)
|
||||
if A_log.numel() != HV or dt_bias.numel() != HV:
|
||||
raise ValueError(
|
||||
f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})."
|
||||
)
|
||||
if out.shape != (B, 1, HV, V):
|
||||
raise ValueError(
|
||||
f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})."
|
||||
)
|
||||
|
||||
qkv_dim = mixed_qkv.shape[1]
|
||||
qk_dim = qkv_dim - HV * V
|
||||
if qk_dim <= 0 or qk_dim % 2 != 0:
|
||||
raise ValueError(
|
||||
f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}."
|
||||
)
|
||||
q_dim = qk_dim // 2
|
||||
if q_dim % K != 0:
|
||||
raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.")
|
||||
H = q_dim // K
|
||||
if H <= 0 or HV % H != 0:
|
||||
raise ValueError(
|
||||
f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}."
|
||||
)
|
||||
|
||||
BK = triton.next_power_of_2(K)
|
||||
if triton.cdiv(K, BK) != 1:
|
||||
raise ValueError(
|
||||
f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})."
|
||||
)
|
||||
BV = min(triton.next_power_of_2(V), 32)
|
||||
num_stages = 3
|
||||
num_warps = 1
|
||||
|
||||
stride_mixed_qkv_tok = mixed_qkv.stride(0)
|
||||
stride_a_tok = a.stride(0)
|
||||
stride_b_tok = b.stride(0)
|
||||
stride_init_state_token = initial_state.stride(0)
|
||||
stride_final_state_token = initial_state.stride(0)
|
||||
stride_indices_seq = ssm_state_indices.stride(0)
|
||||
|
||||
NV = triton.cdiv(V, BV)
|
||||
grid = (NV, B * HV)
|
||||
fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](
|
||||
mixed_qkv=mixed_qkv,
|
||||
a=a,
|
||||
b=b,
|
||||
A_log=A_log,
|
||||
dt_bias=dt_bias,
|
||||
o=out,
|
||||
h0=initial_state,
|
||||
ht=initial_state,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
scale=scale,
|
||||
stride_mixed_qkv_tok=stride_mixed_qkv_tok,
|
||||
stride_a_tok=stride_a_tok,
|
||||
stride_b_tok=stride_b_tok,
|
||||
stride_init_state_token=stride_init_state_token,
|
||||
stride_final_state_token=stride_final_state_token,
|
||||
stride_indices_seq=stride_indices_seq,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
SOFTPLUS_THRESHOLD=20.0,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
return out, initial_state
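A hedged usage sketch for the new packed-decode entry point (all shapes and the null-block convention below are assumptions consistent with the validation checks above, not values taken from the diff):

import torch

B, H, HV, K, V = 4, 8, 16, 64, 128
mixed_qkv = torch.randn(B, 2 * H * K + HV * V, device="cuda", dtype=torch.bfloat16)
a = torch.randn(B, HV, device="cuda")
b = torch.randn(B, HV, device="cuda")
A_log = torch.randn(HV, device="cuda")
dt_bias = torch.randn(HV, device="cuda")
state = torch.zeros(16, HV, V, K, device="cuda")  # slot 0 is the null block and is skipped
idx = torch.arange(1, B + 1, device="cuda")       # one state slot per request
out = torch.empty(B, 1, HV, V, device="cuda", dtype=mixed_qkv.dtype)

o, state = fused_recurrent_gated_delta_rule_packed_decode(
    mixed_qkv, a, b, A_log, dt_bias,
    scale=K ** -0.5,
    initial_state=state,
    out=out,
    ssm_state_indices=idx,
    use_qk_l2norm_in_kernel=True,
)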
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
@@ -264,7 +490,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -296,7 +522,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -324,7 +550,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
inplace_final_state: bool:
|
||||
Whether to store the final state in-place to save memory.
|
||||
Default: `True`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
ssm_state_indices (Optional[torch.Tensor]):
|
||||
@@ -358,7 +584,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
|
||||
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
|
||||
vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py (new file, 279 lines)
@@ -0,0 +1,279 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
"IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
|
||||
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["N", "T"])
|
||||
def fused_sigmoid_gating_delta_rule_update_kernel(
|
||||
A_log,
|
||||
a,
|
||||
b,
|
||||
dt_bias,
|
||||
beta,
|
||||
threshold,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
ssm_state_indices,
|
||||
num_accepted_tokens,
|
||||
scale,
|
||||
N: tl.int64, # num of sequences
|
||||
T: tl.int64, # num of tokens
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
stride_init_state_token: tl.constexpr,
|
||||
stride_final_state_token: tl.constexpr,
|
||||
stride_indices_seq: tl.constexpr,
|
||||
stride_indices_tok: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
IS_KDA: tl.constexpr,
|
||||
):
|
||||
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
if IS_VARLEN:
|
||||
bos, eos = (
|
||||
tl.load(cu_seqlens + i_n).to(tl.int64),
|
||||
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
|
||||
)
|
||||
all = T
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
all = B * T
|
||||
|
||||
if T == 0:
|
||||
# no tokens to process for this sequence
|
||||
return
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
|
||||
p_q = q + (bos * H + i_h) * K + o_k
|
||||
p_k = k + (bos * H + i_h) * K + o_k
|
||||
p_v = v + (bos * HV + i_hv) * V + o_v
|
||||
|
||||
p_A_log = A_log + i_hv
|
||||
if not IS_KDA:
|
||||
p_a = a + bos * HV + i_hv
|
||||
p_dt_bias = dt_bias + i_hv
|
||||
else:
|
||||
p_a = a + (bos * HV + i_hv) * K + o_k
|
||||
p_dt_bias = dt_bias + i_hv * K + o_k
|
||||
|
||||
p_b = b + bos * HV + i_hv
|
||||
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
|
||||
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_v[:, None] & mask_k[None, :]
|
||||
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
if IS_SPEC_DECODING:
|
||||
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
|
||||
else:
|
||||
i_t = 0
|
||||
# Load state index and check for invalid entries
|
||||
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
|
||||
tl.int64
|
||||
)
|
||||
# Skip if state index is invalid (NULL_BLOCK_ID=0)
|
||||
if state_idx <= 0:
|
||||
return
|
||||
p_h0 = h0 + state_idx * stride_init_state_token
|
||||
else:
|
||||
p_h0 = h0 + bos * HV * V * K
|
||||
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
for i_t in range(0, T):
|
||||
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
|
||||
b_b = tl.load(p_b).to(tl.float32)
|
||||
|
||||
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
||||
x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32)
|
||||
softplus_x = tl.where(
|
||||
beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
|
||||
)
|
||||
b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x
|
||||
|
||||
# compute beta_output = sigmoid(b)
|
||||
b_beta = tl.sigmoid(b_b.to(tl.float32))
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q * (tl.rsqrt(tl.sum(b_q * b_q) + 1e-6))
|
||||
b_k = b_k * (tl.rsqrt(tl.sum(b_k * b_k) + 1e-6))
|
||||
b_q = b_q * scale
|
||||
# [BV, BK]
|
||||
if not IS_KDA:
|
||||
b_h *= tl.exp(b_g)
|
||||
else:
|
||||
b_h *= tl.exp(b_g[None, :])
|
||||
# [BV]
|
||||
b_v -= tl.sum(b_h * b_k[None, :], 1)
|
||||
b_v *= b_beta
|
||||
# [BV, BK]
|
||||
b_h += b_v[:, None] * b_k[None, :]
|
||||
# [BV]
|
||||
b_o = tl.sum(b_h * b_q[None, :], 1)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
# keep the states for multi-query tokens
|
||||
if INPLACE_FINAL_STATE:
|
||||
# Load state index and check for invalid entries
|
||||
final_state_idx = tl.load(
|
||||
ssm_state_indices + i_n * stride_indices_seq + i_t
|
||||
).to(tl.int64)
|
||||
# Only store if state index is valid (not NULL_BLOCK_ID=0)
|
||||
if final_state_idx > 0:
|
||||
p_ht = ht + final_state_idx * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
else:
|
||||
p_ht = ht + (bos + i_t) * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
# Update pointers for next timestep
|
||||
p_q += H * K
|
||||
p_k += H * K
|
||||
p_o += HV * V
|
||||
p_v += HV * V
|
||||
p_b += HV
|
||||
p_a += HV
|
||||
|
||||
|
||||
def fused_sigmoid_gating_delta_rule_update(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
beta: float = 1.0,
|
||||
threshold: float = 20.0,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
is_kda: bool = False,
|
||||
):
|
||||
"""
|
||||
Fused triton implementation of sigmoid gating delta rule update.
|
||||
This function uses a single fused kernel that combines both sigmoid gating
|
||||
computation and the recurrent delta rule update for better performance.
|
||||
"""
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
HV = v.shape[2]
|
||||
N = B if cu_seqlens is None else len(cu_seqlens) - 1
|
||||
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, "NK > 1 is not supported yet"
|
||||
num_stages = 3
|
||||
num_warps = 4
|
||||
|
||||
if cu_seqlens is not None and q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]}"
|
||||
f" when using `cu_seqlens`. Please flatten variable-length"
|
||||
f" inputs before processing."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1] ** -0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
|
||||
o = q.new_empty(NK, *v.shape)
|
||||
if inplace_final_state:
|
||||
final_state = initial_state
|
||||
else:
|
||||
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
|
||||
|
||||
stride_init_state_token = initial_state.stride(0)
|
||||
stride_final_state_token = final_state.stride(0)
|
||||
|
||||
if ssm_state_indices is None:
|
||||
stride_indices_seq, stride_indices_tok = 1, 1
|
||||
elif ssm_state_indices.ndim == 1:
|
||||
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
|
||||
else:
|
||||
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
|
||||
|
||||
grid = (NK, NV, N * HV)
|
||||
fused_sigmoid_gating_delta_rule_update_kernel[grid](
|
||||
A_log=A_log,
|
||||
a=a.contiguous(),
|
||||
b=b.contiguous(),
|
||||
dt_bias=dt_bias,
|
||||
beta=beta,
|
||||
threshold=threshold,
|
||||
q=q.contiguous(),
|
||||
k=k.contiguous(),
|
||||
v=v.contiguous(),
|
||||
o=o,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
scale=scale,
|
||||
N=N,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
stride_init_state_token=stride_init_state_token,
|
||||
stride_final_state_token=stride_final_state_token,
|
||||
stride_indices_seq=stride_indices_seq,
|
||||
stride_indices_tok=stride_indices_tok,
|
||||
INPLACE_FINAL_STATE=inplace_final_state,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
IS_KDA=is_kda,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
o = o.squeeze(0)
|
||||
return o, final_state
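For clarity, a plain-PyTorch reference of the gating math the fused kernel above computes per head (a sketch; the fused kernel additionally performs the recurrent delta-rule state update in the same pass):

import torch

def sigmoid_gating_reference(A_log, a, b, dt_bias, beta=1.0, threshold=20.0):
    # g = -exp(A_log) * softplus(a + dt_bias), using a softplus with slope `beta`
    # and a linear tail past `threshold`; the update gate is sigmoid(b).
    x = (a + dt_bias).float()
    softplus_x = torch.where(
        beta * x <= threshold, (1.0 / beta) * torch.log1p(torch.exp(beta * x)), x
    )
    g = -torch.exp(A_log.float()) * softplus_x
    return g, torch.sigmoid(b.float())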
|
||||
@@ -15,14 +15,12 @@ from .utils import tensor_cache
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_indices(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
|
||||
indices = torch.cat(
|
||||
[
|
||||
torch.arange(n)
|
||||
@@ -33,9 +31,7 @@ def prepare_chunk_indices(
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_offsets(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
def prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
|
||||
return torch.cat(
|
||||
[cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
|
||||
).cumsum(-1)
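# Worked example (sketch) for the helpers above: cu_seqlens = [0, 100, 164]
# -> prepare_lens = [100, 64]; with chunk_size = 64, triton.cdiv gives [2, 1],
# so prepare_chunk_offsets = [0, 2, 3].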
|
||||
|
||||
@@ -37,7 +37,7 @@ def fused_recurrent_kda_fwd(
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
@@ -115,7 +115,7 @@ def fused_recurrent_kda(
    initial_state: torch.Tensor = None,
    inplace_final_state: bool = True,
    use_qk_l2norm_in_kernel: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    ssm_state_indices: torch.LongTensor | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -692,7 +692,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
    gk: torch.Tensor | None = None,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -706,7 +706,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
            The beta tensor of shape `[B, T, H]`.
        gk (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
        cu_seqlens (torch.LongTensor):
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
@@ -936,7 +936,7 @@ def recompute_w_u_fwd(
    A: torch.Tensor,
    q: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = A.shape[-1]
@@ -1104,7 +1104,7 @@ def chunk_gla_fwd_o_gk(
    h: torch.Tensor,
    o: torch.Tensor,
    scale: float,
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    chunk_size: int = 64,
):
    B, T, H, K, V = *q.shape, v.shape[-1]
@@ -1148,7 +1148,7 @@ def chunk_kda_fwd(
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
):
    chunk_size = 64
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
@@ -1208,7 +1208,7 @@ def chunk_kda(
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    **kwargs,
):
    if scale is None:
@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    # Map the program id to the starting row of X and Y it should compute.
    row_start = tl.program_id(0) * ROWS_PER_BLOCK
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
    if HAS_Z and not NORM_BEFORE_GATE:
        Z_base = Z + rows[:, None] * stride_z_row + col_offsets
        z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
        x *= z * tl.sigmoid(z)
        if ACTIVATION == "swish" or ACTIVATION == "silu":
            x *= z * tl.sigmoid(z)
        elif ACTIVATION == "sigmoid":
            x *= tl.sigmoid(z)

    # Compute mean and variance per row (reduce along axis 1)
    if not IS_RMS_NORM:
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
    if HAS_Z and NORM_BEFORE_GATE:
        Z_base = Z + rows[:, None] * stride_z_row + col_offsets
        z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
        y *= z * tl.sigmoid(z)
        if ACTIVATION == "swish" or ACTIVATION == "silu":
            y *= z * tl.sigmoid(z)
        elif ACTIVATION == "sigmoid":
            y *= tl.sigmoid(z)

    # Write output
    tl.store(Y_base, y, mask=mask)
@@ -178,6 +185,7 @@ def layer_norm_fwd(
    group_size: int = None,
    norm_before_gate: bool = True,
    is_rms_norm: bool = False,
    activation: str = "swish",
):
    M, N = x.shape
    if group_size is None:
@@ -232,9 +240,12 @@ def layer_norm_fwd(
        eps,
        BLOCK_N=BLOCK_N,
        ROWS_PER_BLOCK=rows_per_block,
        HAS_BIAS=bias is not None,
        HAS_Z=z is not None,
        NORM_BEFORE_GATE=norm_before_gate,
        IS_RMS_NORM=is_rms_norm,
        num_warps=num_warps,
        ACTIVATION=activation,
    )
    return out, mean, rstd

@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
        activation: str = "swish",
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""

@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
            activation=activation,
        )
        ctx.save_for_backward(x, weight, bias, mean, rstd, z)
        ctx.x_shape_og = x_shape_og
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
        ctx.group_size = group_size
        ctx.norm_before_gate = norm_before_gate
        ctx.is_rms_norm = is_rms_norm
        ctx.activation = activation
        return y.reshape(x_shape_og)


@@ -296,17 +310,25 @@ def layernorm_fn(
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
    activation: str = "swish",
):
    return LayerNormFn.apply(
        x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
        x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation
    )


def rmsnorm_fn(
    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    activation: str = "swish",
):
    return LayerNormFn.apply(
        x, weight, bias, z, eps, group_size, norm_before_gate, True
        x, weight, bias, z, eps, group_size, norm_before_gate, True, activation
    )
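The hunks above thread an activation argument from the Python wrappers (layernorm_fn / rmsnorm_fn) down to the Triton kernel, so the gate can use either swish/silu or sigmoid. A minimal, hedged usage sketch (tensor shapes are illustrative, not from the commit):

import torch

x = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)
z = torch.randn_like(x)                                   # gate input
weight = torch.ones(1024, device="cuda", dtype=torch.bfloat16)

# Previous default behaviour: swish/silu gating.
y_silu = rmsnorm_fn(x, weight, None, z=z, norm_before_gate=False, activation="swish")

# New option enabled by this change: sigmoid gating.
y_sig = rmsnorm_fn(x, weight, None, z=z, norm_before_gate=False, activation="sigmoid")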
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        activation: str = "swish",
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.activation = activation
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
            activation=self.activation,
        )

@@ -122,7 +122,7 @@ def recompute_w_u_fwd(
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: torch.LongTensor | None,
    cu_seqlens: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
@@ -22,12 +22,13 @@ from vllm.model_executor.layers.fused_moe.layer import (
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEActivationFormat,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoEExpertsModular,
    FusedMoEPrepareAndFinalizeModular,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
    UnquantizedFusedMoEMethod,
@@ -61,9 +62,10 @@ __all__ = [
    "MoEActivation",
    "UnquantizedFusedMoEMethod",
    "FusedMoeWeightScaleSupported",
    "FusedMoEPermuteExpertsUnpermute",
    "FusedMoEExpertsModular",
    "FusedMoEActivationFormat",
    "FusedMoEPrepareAndFinalize",
    "FusedMoEPrepareAndFinalizeModular",
    "GateLinear",
    "RoutingMethodType",
    "SharedFusedMoE",
    "ZeroExpertFusedMoE",
@@ -137,4 +139,4 @@ else:
        raise NotImplementedError(f"{method} is not implemented as lack of triton.")

    fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk")
    fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")
    fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")
@@ -6,8 +6,7 @@ from enum import Enum

import torch
import torch.nn.functional as F

from vllm._custom_ops import silu_and_mul, gelu_and_mul, swigluoai_and_mul
from vllm import _custom_ops as ops


class MoEActivation(Enum):
@@ -114,14 +113,11 @@ def apply_moe_activation(

    # Activations with gated multiplication (gate × activation(up))
    if activation == MoEActivation.SILU:
        # torch.ops._C.silu_and_mul(output, input)
        silu_and_mul(output, input)
        ops.silu_and_mul(output, input)
    elif activation == MoEActivation.GELU:
        # torch.ops._C.gelu_and_mul(output, input)
        gelu_and_mul(output, input)
        ops.gelu_and_mul(output, input)
    elif activation == MoEActivation.SWIGLUOAI:
        # torch.ops._C.swigluoai_and_mul(output, input)
        swigluoai_and_mul(output, input)
        ops.swigluoai_and_mul(output, input)
    elif activation == MoEActivation.SWIGLUSTEP:
        from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
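For reference, the gated activations dispatched above all follow the same pattern: split the input's last dimension in half, apply the activation to one half and multiply by the other, producing an output of half the width. A plain-PyTorch sketch of the SiLU case (a reference for the semantics, not the fused custom op itself):

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, 2 * intermediate_size] -> [num_tokens, intermediate_size]
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

out = silu_and_mul_reference(torch.randn(8, 256))
assert out.shape == (8, 128)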
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -20,20 +21,15 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNaiveEP,
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
make_moe_prepare_and_finalize_naive_dp_ep,
|
||||
make_moe_prepare_and_finalize_no_dp_ep,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
if has_pplx():
|
||||
from .pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize,
|
||||
pplx_hidden_dim_scale_bytes,
|
||||
)
|
||||
if has_deep_ep():
|
||||
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
|
||||
from .deepep_ll_prepare_finalize import (
|
||||
@@ -81,6 +77,7 @@ def maybe_make_prepare_finalize(
|
||||
quant_config: FusedMoEQuantConfig | None,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
allow_new_interface: bool = False,
|
||||
use_monolithic: bool = False,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
# NOTE(rob): we are migrating each quant_method to hold the MK
|
||||
# in all cases. The allow_new_interface=False flag allow us to fall
|
||||
@@ -106,65 +103,25 @@ def maybe_make_prepare_finalize(
|
||||
"Detected DP deployment with no --enable-expert-parallel. "
|
||||
"Falling back to AllGather+ReduceScatter dispatch/combine."
|
||||
)
|
||||
return MoEPrepareAndFinalizeNaiveEP(
|
||||
return make_moe_prepare_and_finalize_naive_dp_ep(
|
||||
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
|
||||
num_dispatchers=(
|
||||
get_ep_group().device_communicator.all2all_manager.world_size
|
||||
),
|
||||
use_monolithic=use_monolithic,
|
||||
)
|
||||
else:
|
||||
return MoEPrepareAndFinalizeNoEP()
|
||||
return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)
|
||||
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
|
||||
|
||||
if moe.use_pplx_kernels:
|
||||
assert quant_config is not None
|
||||
|
||||
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
|
||||
moe.max_num_tokens,
|
||||
moe.hidden_dim,
|
||||
moe.in_dtype,
|
||||
quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
)
|
||||
|
||||
all_to_all_args = dict(
|
||||
max_num_tokens=moe.max_num_tokens,
|
||||
num_experts=moe.num_experts,
|
||||
experts_per_token=moe.experts_per_token, # topk
|
||||
rank=all2all_manager.rank,
|
||||
world_size=all2all_manager.world_size,
|
||||
# dp_size actually means tp_size, bug in pplx kernels
|
||||
dp_size=all2all_manager.tp_group.world_size,
|
||||
hidden_dim=moe.hidden_dim,
|
||||
hidden_dim_bytes=hidden_dim_bytes,
|
||||
hidden_dim_scale_bytes=hidden_scale_bytes,
|
||||
)
|
||||
|
||||
num_dispatchers = (
|
||||
all2all_manager.world_size // all2all_manager.tp_group.world_size
|
||||
)
|
||||
|
||||
# Intranode pplx a2a takes a group name while internode does not.
|
||||
if not all2all_manager.internode:
|
||||
all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name
|
||||
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
|
||||
prepare_finalize = PplxPrepareAndFinalize(
|
||||
handle,
|
||||
max_num_tokens=moe.max_num_tokens,
|
||||
num_local_experts=moe.num_local_experts,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
elif moe.use_deepep_ht_kernels:
|
||||
if moe.use_deepep_ht_kernels:
|
||||
assert moe.dp_size == all2all_manager.dp_world_size
|
||||
|
||||
all_to_all_args = dict()
|
||||
all_to_all_args: dict[str, Any] = dict()
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
prepare_finalize = DeepEPHTPrepareAndFinalize(
|
||||
handle,
|
||||
@@ -246,8 +203,9 @@ def maybe_make_prepare_finalize(
|
||||
)
|
||||
|
||||
elif moe.use_naive_all2all_kernels and allow_new_interface:
|
||||
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
|
||||
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
|
||||
prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep(
|
||||
use_monolithic=use_monolithic,
|
||||
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
|
||||
num_dispatchers=all2all_manager.world_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -261,7 +261,7 @@ def persistent_masked_m_silu_mul_quant(
    return y_q, y_s


class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular):
    def __init__(
        self,
        moe_config: FusedMoEConfig,

@@ -228,6 +228,7 @@ class FusedMoEQuantConfig:
    _a2: FusedMoEQuantDesc
    _w1: FusedMoEQuantDesc
    _w2: FusedMoEQuantDesc
    is_nvfp4_scale_swizzled: bool = True

    def __post_init__(self):
        assert not self.per_act_token_quant or self.block_shape is None, (
@@ -475,6 +476,7 @@ class FusedMoEQuantConfig:
        w1_zp: torch.Tensor | None = None,
        w2_zp: torch.Tensor | None = None,
        weight_dtype: torch.dtype | str | None = None,
        is_nvfp4_scale_swizzled: bool = True,
    ) -> "FusedMoEQuantConfig":
        """
        General builder function for a FusedMoEQuantConfig.
@@ -504,6 +506,7 @@ class FusedMoEQuantConfig:
        - w2_bias: Optional biases for w1 (GPT OSS Triton).
        - w1_zp: Optional w1 zero points for int4/int8 quantization.
        - w2_zp: Optional w2 zero points for int4/int8 quantization.
        - is_nvfp4_scale_swizzled: Whether the nvfp4 scales should be swizzled for the MoE kernel.
        """
        assert not isinstance(quant_dtype, str) or quant_dtype in {
            "nvfp4",
@@ -536,6 +539,7 @@ class FusedMoEQuantConfig:
            _w2=FusedMoEQuantDesc(
                weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
            ),
            is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
        )
        assert quant_config.per_act_token_quant == per_act_token_quant
        assert quant_config.per_out_ch_quant == per_out_ch_quant
@@ -737,6 +741,7 @@ def nvfp4_moe_quant_config(
    w2_scale: torch.Tensor,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    is_nvfp4_scale_swizzled: bool = True,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for mxfp4 activations and nvfp4 weights.
@@ -754,6 +759,7 @@ def nvfp4_moe_quant_config(
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=None,
        is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
    )


@@ -939,10 +945,6 @@ class FusedMoEParallelConfig:
    def use_all2all_kernels(self):
        return self.dp_size > 1 and self.use_ep

    @property
    def use_pplx_kernels(self):
        return self.use_all2all_kernels and self.all2all_backend == "pplx"

    @property
    def use_deepep_ht_kernels(self):
        return (
@@ -962,7 +964,7 @@ class FusedMoEParallelConfig:

    @property
    def use_batched_activation_format(self):
        return self.use_deepep_ll_kernels or self.use_pplx_kernels
        return self.use_deepep_ll_kernels

    @property
    def use_naive_all2all_kernels(self):
@@ -1221,10 +1223,6 @@ class FusedMoEConfig:
    def use_ep(self):
        return self.moe_parallel_config.use_ep

    @property
    def use_pplx_kernels(self):
        return self.moe_parallel_config.use_pplx_kernels

    @property
    def use_deepep_ht_kernels(self):
        return self.moe_parallel_config.use_deepep_ht_kernels
@@ -0,0 +1,147 @@
|
||||
{
|
||||
"triton_version": "3.6.0",
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
{
|
||||
"triton_version": "3.6.0",
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
moe_unpermute,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
@@ -166,7 +166,7 @@ def run_cutlass_moe_fp8(
|
||||
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
|
||||
ops.get_cutlass_pplx_moe_mm_data(
|
||||
ops.get_cutlass_batched_moe_mm_data(
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
@@ -262,7 +262,7 @@ def run_cutlass_moe_fp8(
|
||||
)
|
||||
|
||||
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
@@ -661,7 +661,7 @@ def run_cutlass_moe_fp4(
|
||||
return
|
||||
|
||||
|
||||
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
"""CUTLASS FP4 fused MoE expert implementation."""
|
||||
|
||||
@property
|
||||
@@ -928,7 +928,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
)
|
||||
|
||||
|
||||
class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
out_dtype: torch.dtype | None,
|
||||
@@ -1170,8 +1170,8 @@ def cutlass_moe_w4a8_fp8(
|
||||
|
||||
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
fn = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
CutlassExpertsW4A8Fp8(
|
||||
out_dtype=a.dtype,
|
||||
a_strides1=a_strides1,
|
||||
@@ -1186,10 +1186,9 @@ def cutlass_moe_w4a8_fp8(
|
||||
quant_config=quant_config,
|
||||
group_size=group_size,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return fn(
|
||||
return fn.apply(
|
||||
a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
|
||||
@@ -113,7 +113,7 @@ def _valid_deep_gemm(
|
||||
return True
|
||||
|
||||
|
||||
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class DeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
"""DeepGemm-based fused MoE expert implementation."""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
|
||||
|
||||
@@ -25,7 +25,7 @@ from vllm.v1.worker.ubatching import (
|
||||
)
|
||||
|
||||
|
||||
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Prepare/Finalize using DeepEP High-Throughput kernels.
|
||||
"""
|
||||
@@ -123,7 +123,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
is_token_in_rank,
|
||||
event,
|
||||
) = self.buffer.get_dispatch_layout(
|
||||
topk_idx=rank_topk_ids,
|
||||
topk_idx=rank_topk_ids.long(),
|
||||
num_experts=num_experts,
|
||||
previous_event=previous_event,
|
||||
async_finish=False,
|
||||
@@ -148,7 +148,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
|
||||
is_token_in_rank=is_token_in_rank,
|
||||
num_tokens_per_expert=dispatch_expert_num_tokens,
|
||||
topk_idx=rank_topk_ids,
|
||||
topk_idx=rank_topk_ids.long(),
|
||||
topk_weights=rank_topk_weights,
|
||||
# expert_alignment rounds the number of tokens per expert
|
||||
# to this value.
|
||||
@@ -169,7 +169,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
event,
|
||||
has_scales,
|
||||
token_data,
|
||||
expert_topk_ids,
|
||||
expert_topk_ids.int(),
|
||||
num_experts,
|
||||
expert_num_tokens_per_expert_list,
|
||||
expert_topk_weights,
|
||||
@@ -239,6 +239,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=False,
|
||||
block_shape=quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
@@ -49,7 +49,7 @@ def dequant_fp8(
|
||||
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
|
||||
|
||||
|
||||
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Prepare/Finalize using DeepEP low-latency kernels.
|
||||
"""
|
||||
@@ -119,7 +119,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# time. This setting is handled by post_init_setup.
|
||||
self.use_ue8m0_dispatch = False
|
||||
|
||||
def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def post_init_setup(self, fused_experts: mk.FusedMoEExperts):
|
||||
if not fused_experts.supports_packed_ue8m0_act_scales():
|
||||
# Early exit.
|
||||
return
|
||||
@@ -297,12 +297,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
|
||||
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
|
||||
a1,
|
||||
dispatch_topk_ids,
|
||||
dispatch_topk_ids.long(),
|
||||
self.max_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
round_scale=self.use_ue8m0_dispatch,
|
||||
use_ue8m0=self.use_ue8m0_dispatch,
|
||||
# round_scale=self.use_ue8m0_dispatch,
|
||||
# use_ue8m0=self.use_ue8m0_dispatch,
|
||||
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
|
||||
**(
|
||||
dict(x_global_scale=qc_a1_gscale_or_scale)
|
||||
@@ -398,7 +398,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
dbo_maybe_run_recv_hook()
|
||||
_, _, recv_hook = self.buffer.low_latency_combine(
|
||||
fused_expert_output,
|
||||
combine_topk_ids,
|
||||
combine_topk_ids.long(),
|
||||
combine_topk_weights,
|
||||
handle,
|
||||
async_finish=False,
|
||||
|
||||
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py (new file, 335 lines)
@@ -0,0 +1,335 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
||||
"""
|
||||
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(moe_config, quant_config)
|
||||
|
||||
if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor:
|
||||
raise NotImplementedError(
|
||||
"EP parallelism is not supported with TRTLLM"
|
||||
"per-tensor FP8 quantization."
|
||||
)
|
||||
|
||||
self.routing_method_type = moe_config.routing_method
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.intermediate_size_per_partition = (
|
||||
moe_config.intermediate_size_per_partition
|
||||
)
|
||||
self.hidden_dim = moe_config.hidden_dim
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
|
||||
# Make additional scales for per-tensor interface.
|
||||
if self.quant_config.is_per_tensor:
|
||||
w1_scale = self.quant_config.w1_scale
|
||||
assert w1_scale is not None
|
||||
a1_scale = self.quant_config.a1_scale
|
||||
assert a1_scale is not None
|
||||
w2_scale = self.quant_config.w2_scale
|
||||
assert w2_scale is not None
|
||||
a2_scale = self.quant_config.a2_scale
|
||||
assert a2_scale is not None
|
||||
|
||||
self._g1_alphas = (w1_scale * a1_scale).squeeze()
|
||||
self._g2_alphas = (w2_scale * a2_scale).squeeze()
|
||||
self._g1_scale_c = (
|
||||
self._g1_alphas / self.quant_config.a2_scale
|
||||
if moe_config.is_act_and_mul
|
||||
else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
# Add check flashinfer trtllm is available
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor and Fp8 block."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""Supports only SiLU and RELU^2 non-gated activation."""
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""Monolithic kernel so only use with naive DP/EP and TP."""
|
||||
return (
|
||||
not moe_parallel_config.use_all2all_kernels
|
||||
or moe_parallel_config.use_naive_all2all_kernels
|
||||
) and not moe_parallel_config.enable_eplb
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return True
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def _apply_per_block(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Delay import for non-CUDA.
|
||||
import flashinfer
|
||||
|
||||
assert not apply_router_weight_on_input
|
||||
assert activation == MoEActivation.SILU
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype)
|
||||
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
assert self.topk <= global_num_experts
|
||||
assert self.topk <= 10
|
||||
assert global_num_experts % 4 == 0
|
||||
assert self.quant_config.block_shape == [128, 128]
|
||||
# The routing kernel expects #experts <= #threads (512).
|
||||
assert global_num_experts <= 512
|
||||
|
||||
# Kernel requires transposed hidden state scales
|
||||
# TODO: fuse into the quant kernel.
|
||||
assert a1q_scale is not None
|
||||
a1q_scale_t = a1q_scale.t().contiguous()
|
||||
|
||||
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale_t,
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.quant_config.w2_scale,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=(num_expert_group or 0),
|
||||
topk_group=(topk_group or 0),
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
routing_method_type=self.routing_method_type,
|
||||
use_shuffled_weight=False,
|
||||
)
|
||||
|
||||
def _apply_per_tensor(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Delay import for non-CUDA.
|
||||
import flashinfer
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
# Confirm supported activation function.
|
||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
activation_type = ActivationType(activation_to_flashinfer_int(activation))
|
||||
|
||||
# Confirm Llama-4 routing is proper.
|
||||
if self.routing_method_type == RoutingMethodType.Llama4:
|
||||
assert apply_router_weight_on_input
|
||||
else:
|
||||
assert not apply_router_weight_on_input
|
||||
|
||||
# The DeepSeekV3 routing method requires float32 router logits.
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states,
|
||||
gemm1_weights=w1,
|
||||
output1_scales_scalar=self._g1_scale_c,
|
||||
output1_scales_gate_scalar=self._g1_alphas,
|
||||
gemm2_weights=w2,
|
||||
output2_scales_scalar=self._g2_alphas,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=num_expert_group or 0,
|
||||
topk_group=topk_group or 0,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_routing_scales_on_input=apply_router_weight_on_input,
|
||||
routing_method_type=self.routing_method_type,
|
||||
activation_type=activation_type,
|
||||
)
|
||||
return out
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.block_shape is not None:
|
||||
return self._apply_per_block(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
router_logits,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
apply_router_weight_on_input,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
elif self.quant_config.is_per_tensor:
|
||||
return self._apply_per_tensor(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
router_logits,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
apply_router_weight_on_input,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only per-block and per-tensor quantization are supported in "
|
||||
f"{self.__class__.__name__}."
|
||||
)
|
||||
vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py (new file, 326 lines)
@@ -0,0 +1,326 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class TrtLlmNvFp4ExpertsBase:
|
||||
"""
|
||||
NvFp4 TRTLLM-Gen MoE kernels. Supports modular and monolithic interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
self.moe_config = moe_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.routing_method_type = self.moe_config.routing_method
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.intermediate_size_per_partition = (
|
||||
moe_config.intermediate_size_per_partition
|
||||
)
|
||||
self.hidden_dim = moe_config.hidden_dim
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
|
||||
assert self.quant_config.g1_alphas is not None
|
||||
assert self.quant_config.a2_gscale is not None
|
||||
if moe_config.is_act_and_mul:
|
||||
# g1_alpha_s = a13_scale * w13_scale_2
|
||||
# a2_gscale = (1 / a2_scale)
|
||||
# g1_scale_c = a13_scale * w13_scale_2 / a2_scale
|
||||
self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
|
||||
else:
|
||||
self.g1_scale_c = (
|
||||
torch.ones_like(self.quant_config.a1_gscale)
|
||||
* self.quant_config.a2_gscale
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""Supports non-gated MoE (i.e. Nemotron-Nano)."""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Nvfp4 quantization."""
|
||||
SUPPORTED_W_A = [
|
||||
(kNvfp4Static, kNvfp4Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""Supports only SiLU and RELU^2 non-gated activation."""
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
@staticmethod
|
||||
def _supports_shape(hidden_dim: int) -> bool:
|
||||
"""Requires hidden dim to be multiple of 512."""
|
||||
return hidden_dim % 512 == 0
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
Modular version of the implementation (just the experts).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""The modular implementation supports all parallel configs."""
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# The workspaces for this implementation are managed by flashinfer.
|
||||
workspace1 = (0,)
|
||||
workspace2 = (0,)
|
||||
|
||||
# Hidden states are Nvfp4, packed into int8 dtype, so we
|
||||
# need to multiply K by 2 to get the output shape right.
|
||||
assert self.hidden_dim == K * 2
|
||||
output = (M, self.hidden_dim)
|
||||
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert a1q_scale is not None
|
||||
assert self.quant_config.w1_scale is not None
|
||||
assert self.quant_config.w2_scale is not None
|
||||
|
||||
# Pack topk ids and weights into format expected by the kernel.
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16
|
||||
).view(torch.int16)
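# Layout of each packed int32 entry, as implied by the expression above:
#   bits 16..31 hold the expert id (topk_ids), bits 0..15 hold the raw bf16 bit
#   pattern of the routing weight, so the kernel can recover both from one tensor.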
|
||||
|
||||
# trtllm_fp4_block_scale_routed_moe does not support autotuning
|
||||
# so skip this kernel during dummy run for autotuning.
|
||||
import vllm.utils.flashinfer as fi_utils
|
||||
|
||||
if fi_utils._is_fi_autotuning:
|
||||
return hidden_states
|
||||
|
||||
# Invoke kernel.
|
||||
flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
|
||||
topk_ids=packed_tensor,
|
||||
routing_bias=None,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*hidden_states.shape[:-1], -1
|
||||
),
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=self.g1_scale_c,
|
||||
output1_scale_gate_scalar=self.quant_config.g1_alphas,
|
||||
output2_scale_scalar=self.quant_config.g2_alphas,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=0,
|
||||
topk_group=0,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=1,
|
||||
do_finalize=True,
|
||||
activation_type=activation_to_flashinfer_int(activation),
|
||||
output=output,
|
||||
)
|
||||
|
||||
|
||||
class TrtLlmNvFp4ExpertsMonolithic(
|
||||
TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsMonolithic
|
||||
):
|
||||
"""
|
||||
Monolithic version of the kernel (router + experts).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""The modular implementation should be used for the Dp/Ep or EPLB case."""
|
||||
return (
|
||||
not moe_parallel_config.use_all2all_kernels
|
||||
and not moe_parallel_config.enable_eplb
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method_type: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
# NOTE(rob): this is a conservative list.
|
||||
return routing_method_type in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
RoutingMethodType.Llama4,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert a1q_scale is not None
|
||||
assert self.quant_config.w1_scale is not None
|
||||
assert self.quant_config.w2_scale is not None
|
||||
assert (
|
||||
apply_router_weight_on_input
|
||||
and self.routing_method_type == RoutingMethodType.Llama4
|
||||
) or (
|
||||
not apply_router_weight_on_input
|
||||
and self.routing_method_type != RoutingMethodType.Llama4
|
||||
)
|
||||
|
||||
# Prepare routing bias into kernel format.
|
||||
routing_bias = e_score_correction_bias
|
||||
if routing_bias is not None:
|
||||
routing_bias = routing_bias.to(torch.bfloat16)
|
||||
router_logits = (
|
||||
router_logits.to(torch.float32)
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3
|
||||
else router_logits
|
||||
)
|
||||
|
||||
# Invoke kernel.
|
||||
return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*hidden_states.shape[:-1], -1
|
||||
),
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=self.g1_scale_c,
|
||||
output1_scale_gate_scalar=self.quant_config.g1_alphas,
|
||||
output2_scale_scalar=self.quant_config.g2_alphas,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=(num_expert_group or 0),
|
||||
topk_group=(topk_group or 0),
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
routing_method_type=self.routing_method_type,
|
||||
do_finalize=True,
|
||||
)[0]
|
||||
@@ -11,13 +11,13 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
|
||||
|
||||
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
class FallbackExperts(mk.FusedMoEExpertsModular, ABC):
|
||||
"""Base class for runtime dispatching of expert implementations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experts: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
experts: mk.FusedMoEExpertsModular,
|
||||
fallback_experts: mk.FusedMoEExpertsModular,
|
||||
):
|
||||
super().__init__(
|
||||
moe_config=experts.moe_config, quant_config=experts.quant_config
|
||||
@@ -27,8 +27,8 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
|
||||
@staticmethod
|
||||
def get_clses() -> tuple[
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
]:
|
||||
"""
|
||||
Get the cls for the experts and fallback experts.
|
||||
@@ -149,7 +149,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
raise NotImplementedError
|
||||
|
||||
def apply(
|
||||
|
||||
@@ -18,7 +18,7 @@ def get_local_sizes():
|
||||
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
|
||||
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""Base class for FlashInfer MoE prepare and finalize operations."""
|
||||
|
||||
def __init__(
|
||||
@@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch(
|
||||
ep_size,
|
||||
)
|
||||
|
||||
# Swizzle after the A2A if nvfp4.
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
# Swizzle after the A2A if MoE kernel expects swizzled scales.
|
||||
if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
|
||||
if x_sf.element_size() == 1:
|
||||
x_sf = x_sf.view(torch.uint8)
|
||||
x_sf = nvfp4_block_scale_interleave(x_sf)
|
||||
|
||||
@@ -30,7 +30,7 @@ from vllm.utils.flashinfer import (
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
@@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe(
|
||||
return True
|
||||
|
||||
|
||||
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class FlashInferExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: mk.FusedMoEConfig,
|
||||
|
||||
@@ -10,16 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEParallelConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
@@ -39,49 +29,10 @@ def _supports_no_act_and_mul() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor and Fp8 block."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
|
||||
def _supports_routing_method(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
|
||||
def _supports_routing_method_bf16(
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
@@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
|
||||
return not moe_parallel_config.enable_eplb
|
||||
|
||||
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return True
|
||||
|
||||
|
||||
def is_supported_config_trtllm_fp8(
|
||||
moe_config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
|
||||
"""
|
||||
|
||||
def _make_reason(reason: str) -> str:
|
||||
return f"kernel does not support {reason}"
|
||||
|
||||
if not _supports_current_device():
|
||||
return False, _make_reason(f"current device {current_platform.device_name}")
|
||||
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
|
||||
return False, _make_reason("no act_and_mul MLP layer")
|
||||
elif not _supports_activation(moe_config.activation):
|
||||
return False, _make_reason(f"{moe_config.activation} activation")
|
||||
elif not _supports_quant_scheme(weight_key, activation_key):
|
||||
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
|
||||
elif not _supports_parallel_config(moe_config.moe_parallel_config):
|
||||
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
|
||||
elif not _supports_routing_method(
|
||||
weight_key, activation_key, moe_config.routing_method
|
||||
):
|
||||
return False, _make_reason(f"routing method {moe_config.routing_method}")
|
||||
elif activation_format != mk.FusedMoEActivationFormat.Standard:
|
||||
return False, _make_reason(f"activation format {activation_format}")
|
||||
elif not _supports_router_logits_dtype(
|
||||
moe_config.router_logits_dtype, moe_config.routing_method
|
||||
):
|
||||
return False, _make_reason(
|
||||
"float32 router_logits with non-DeepSeekV3 routing "
|
||||
f"{moe_config.router_logits_dtype}x{moe_config.routing_method}"
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def is_supported_config_trtllm_bf16(
|
||||
moe_config: FusedMoEConfig,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
@@ -183,199 +78,6 @@ def is_supported_config_trtllm_bf16(
|
||||
return True, None
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w2_weight_scale_inv: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: list[int],
|
||||
routing_method_type: int,
|
||||
routed_scaling: float | None = 1.0,
|
||||
) -> torch.Tensor:
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||
topk_group = topk_group if topk_group is not None else 0
|
||||
assert top_k <= global_num_experts
|
||||
assert top_k <= 10
|
||||
assert global_num_experts % 4 == 0
|
||||
assert block_shape == [128, 128]
|
||||
# The routing kernel expects num_experts to be no larger than its 512 threads.
|
||||
assert global_num_experts <= 512
|
||||
|
||||
# The DeepSeekV3 routing method requires float32 router logits.
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
routing_logits = routing_logits.to(torch.float32)
|
||||
|
||||
if routing_bias is not None:
|
||||
routing_bias = routing_bias.to(x.dtype)
|
||||
|
||||
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
|
||||
# NOTE: scales of hidden states have to be transposed!
|
||||
a_sf_t = a_sf.t().contiguous()
|
||||
return flashinfer_trtllm_fp8_block_scale_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=a_q,
|
||||
hidden_states_scale=a_sf_t,
|
||||
gemm1_weights=w13_weight,
|
||||
gemm1_weights_scale=w13_weight_scale_inv,
|
||||
gemm2_weights=w2_weight,
|
||||
gemm2_weights_scale=w2_weight_scale_inv,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routed_scaling_factor=routed_scaling,
|
||||
routing_method_type=routing_method_type,
|
||||
use_shuffled_weight=False,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w2_weight_scale_inv: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
intermediate_size: int,
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: list[int],
|
||||
routing_method_type: int,
|
||||
routed_scaling: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
# TODO(bnell): Does this really need to be a torch.op?
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_fused_moe_blockscale_fp8",
|
||||
op_func=flashinfer_fused_moe_blockscale_fp8,
|
||||
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
def fi_trtllm_fp8_per_tensor_moe(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
activation_type: int,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||
topk_group = topk_group if topk_group is not None else 0
|
||||
|
||||
quant_hidden_states, _ = moe_kernel_quantize_input(
|
||||
hidden_states,
|
||||
input_scale,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe
|
||||
|
||||
# The DeepSeekV3 routing method requires float32 router logits.
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
routing_logits = routing_logits.to(torch.float32)
|
||||
|
||||
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=quant_hidden_states,
|
||||
gemm1_weights=gemm1_weights,
|
||||
output1_scales_scalar=output1_scales_scalar,
|
||||
output1_scales_gate_scalar=output1_scales_gate_scalar,
|
||||
gemm2_weights=gemm2_weights,
|
||||
output2_scales_scalar=output2_scales_scalar,
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=local_expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_routing_scales_on_input=use_routing_scales_on_input,
|
||||
routing_method_type=routing_method_type,
|
||||
# TODO: the enum type is required for flashinfer==0.6.3; remove once updated
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/2508
|
||||
activation_type=ActivationType(activation_type),
|
||||
)
|
||||
|
||||
|
||||
def fi_trtllm_fp8_per_tensor_moe_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
activation_type: int,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
# TODO(bnell): Does this really need to be a torch.op?
|
||||
direct_register_custom_op(
|
||||
op_name="fi_trtllm_fp8_per_tensor_moe",
|
||||
op_func=fi_trtllm_fp8_per_tensor_moe,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_bf16(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
|
||||
@@ -489,11 +489,11 @@ def invoke_moe_batched_triton_kernel(
|
||||
)
|
||||
|
||||
|
||||
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
A reference prepare/finalize class that reorganizes the tokens into
|
||||
expert batched format, i.e. E x max_num_tokens x K. This is the format
|
||||
that the PPLX dispatch/combine kernels use.
|
||||
that the batched dispatch/combine kernels use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -645,10 +645,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
)
|
||||
|
||||
|
||||
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
A reference MoE expert class that operates on expert batched format,
|
||||
i.e. E x max_num_tokens x K. This is the format that the pplx
|
||||
i.e. E x max_num_tokens x K. This is the format that the batched
|
||||
dispatch/combine kernels use.
|
||||
"""
|
||||
|
||||
@@ -877,10 +877,10 @@ def batched_moe_kernel_quantize_input(
|
||||
return A_q, A_q_scale
|
||||
|
||||
|
||||
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class BatchedTritonExperts(mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
A Triton based MoE expert class that operates on expert batched format,
|
||||
i.e. E x max_num_tokens x K. This is the format that the pplx
|
||||
i.e. E x max_num_tokens x K. This is the format that the batched
|
||||
dispatch/combine kernels use.
|
||||
"""
|
||||
|
||||
|
||||
@@ -526,7 +526,7 @@ def batched_fused_marlin_moe(
|
||||
return output
|
||||
|
||||
|
||||
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class MarlinExpertsBase(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
@@ -53,7 +53,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import ixformer.inference.functions as ixfops
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.distributed import get_ep_group
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -575,56 +578,6 @@ def fused_moe_kernel(
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
def invoke_fused_moe_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
topk_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
top_k: int,
|
||||
config: dict[str, Any],
|
||||
compute_type: tl.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: list[int] | None = None,
|
||||
B_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
ops.invoke_fused_moe_kernel(
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
mul_routed_weight,
|
||||
top_k,
|
||||
config,
|
||||
compute_type,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
# NOTE(zyongye): we can remove all the wna16 kernel
|
||||
# once we drop off sm75 support
|
||||
def invoke_fused_moe_wna16_cuda_kernel(
|
||||
@@ -782,6 +735,7 @@ def invoke_fused_moe_triton_kernel(
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
topk_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor | None,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
@@ -799,7 +753,9 @@ def invoke_fused_moe_triton_kernel(
|
||||
):
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
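|
||||
# The overlay dispatches to the vendor custom fused-MoE op and returns early,
|
||||
# skipping the Triton path below.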
|
||||
ops.invoke_fused_moe_kernel(A, B, C, A_scale, B_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, mul_routed_weight, top_k, config, compute_type, use_fp8_w8a8, use_int8_w8a16, block_shape, B_bias)
|
||||
return
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
@@ -910,32 +866,6 @@ def dispatch_fused_moe_kernel(
|
||||
block_shape: list[int] | None = None,
|
||||
B_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
invoke_fused_moe_kernel(
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
A_scale,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
mul_routed_weight,
|
||||
top_k,
|
||||
config,
|
||||
compute_type,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
B_bias
|
||||
)
|
||||
return
|
||||
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
|
||||
@@ -999,6 +929,7 @@ def dispatch_fused_moe_kernel(
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
@@ -1397,14 +1328,13 @@ def get_default_config(
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
}
|
||||
# TODO
|
||||
numel = M * topk
|
||||
if numel <= 64:
|
||||
config["BLOCK_SIZE_M"] = 32
|
||||
config['BLOCK_SIZE_M'] = 32
|
||||
elif numel <= 1024:
|
||||
config["BLOCK_SIZE_M"] = 64
|
||||
config['BLOCK_SIZE_M'] = 64
|
||||
else:
|
||||
config["BLOCK_SIZE_M"] = 256
|
||||
config['BLOCK_SIZE_M'] = 256
|
||||
return config
|
||||
|
||||
|
||||
@@ -1424,14 +1354,12 @@ def try_get_optimal_moe_config(
|
||||
else:
|
||||
# First try to load optimal config from the file
|
||||
E, _, N = w2_shape
|
||||
if dtype == "int4_w4a16":
|
||||
N = N * 2
|
||||
block_n = block_shape[0] if block_shape else 0
|
||||
block_k = block_shape[1] if block_shape else 0
|
||||
configs = get_moe_configs(E, N, dtype, block_n, block_k)
|
||||
# block_n = block_shape[0] if block_shape else 0
|
||||
# block_k = block_shape[1] if block_shape else 0
|
||||
# configs = get_moe_configs(E, N, dtype, block_n, block_k)
|
||||
|
||||
configs = None
|
||||
|
||||
|
||||
if configs:
|
||||
# If an optimal configuration map has been found, look up the
|
||||
# optimal config
|
||||
@@ -1560,13 +1488,12 @@ def outplace_fused_experts(
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
return fused_experts_impl_opt(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
False,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
@@ -1626,14 +1553,12 @@ direct_register_custom_op(
|
||||
|
||||
|
||||
def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
|
||||
# torch.ops.vllm.inplace_fused_experts(**kwargs)
|
||||
inplace_fused_experts(**kwargs)
|
||||
hidden_states = kwargs["hidden_states"]
|
||||
hidden_states = kwargs['hidden_states']
|
||||
return hidden_states
|
||||
|
||||
|
||||
def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
|
||||
# return torch.ops.vllm.outplace_fused_experts(**kwargs)
|
||||
return outplace_fused_experts(**kwargs)
|
||||
|
||||
|
||||
@@ -1661,7 +1586,6 @@ def fused_experts(
|
||||
"""Run fused MoE expert computation using Triton kernels."""
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
assert not inplace or not disable_inplace()
|
||||
|
||||
return dispatch_fused_experts_func(inplace)(
|
||||
@@ -1691,6 +1615,245 @@ def fused_experts(
|
||||
w2_bias=quant_config.w2_bias,
|
||||
)
|
||||
|
||||
def fused_experts_impl_opt(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
ocp_mx_scheme: str | None = None,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
output: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
# check constraints
|
||||
if (use_fp8_w8a8 or use_int8_w8a8 or use_int8_w8a16 or use_int4_w4a16
|
||||
or w1_scale is not None or w2_scale is not None or w1_zp is not None
|
||||
or w2_zp is not None or a1_scale is not None or a2_scale is not None):
|
||||
raise ValueError("Quantized MoE is not supported")
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
use_ep = expert_map is not None
|
||||
|
||||
# Expert parallelism is not supported on this fast path yet.
|
||||
if attn_metadata:
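|
||||
# The batch counts as decode-only when every attention group reports decodes and no prefill tokens.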
|
||||
only_decode = not use_ep and all(t.num_decodes > 0 and t.num_prefills == 0 for t in attn_metadata.values())
|
||||
else:
|
||||
only_decode = False
|
||||
|
||||
assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
|
||||
num_tokens = hidden_states.size(0)
|
||||
num_experts = w1.size(0)
|
||||
top_k = topk_weights.size(1)
|
||||
|
||||
if use_ep:
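|
||||
# Each EP rank owns a contiguous block of the global expert ids.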
|
||||
local_num_experts = w1.size(0)
|
||||
start_eid = get_ep_group().device_group.rank() * local_num_experts
|
||||
end_eid = min((get_ep_group().device_group.rank() + 1) * local_num_experts, global_num_experts)
|
||||
hidden_size = hidden_states.shape[1]
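|
||||
# Build scatter/gather indices covering only the experts owned by this rank.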
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
expert_sizes_gpu,
|
||||
expert_sizes_cpu,
|
||||
expand_tokens,
|
||||
) = ixfops.moe_compute_token_index_ep(
|
||||
topk_ids=topk_ids,
|
||||
num_experts=global_num_experts,
|
||||
start_expert_id=start_eid,
|
||||
end_expert_id=end_eid,
|
||||
)
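|
||||
# If nothing was routed to this rank's local experts, contribute an all-zero output.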
|
||||
if expert_sizes_cpu.sum() == 0:
|
||||
return torch.zeros(
|
||||
(num_tokens, hidden_size),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
else:
|
||||
expand_tokens = num_tokens * top_k
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
expert_sizes_gpu,
|
||||
expert_sizes_cpu,
|
||||
) = ixfops.moe_compute_token_index(
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
)
|
||||
|
||||
if only_decode:
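|
||||
# Decode-only batches are small, so the grouped GEMV kernels are used here.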
|
||||
# expand + reorder
|
||||
hidden_states = ixfops.moe_expand_input(
|
||||
hidden_states=hidden_states,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_tokens=expand_tokens,
|
||||
topk=top_k,
|
||||
src_to_dst=src_to_dst,
|
||||
)
|
||||
|
||||
# group gemm 1
|
||||
pt_output_1 = ixfops.moe_w16a16_group_gemv(
|
||||
input=hidden_states,
|
||||
weight=w1,
|
||||
output_dtype=hidden_states.dtype,
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
dst_to_src=None,
|
||||
bias=w1_bias,
|
||||
format="TN",
|
||||
)
|
||||
|
||||
# act
|
||||
if activation == "silu":
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
elif activation == "gelu":
|
||||
pt_output_2 = ixfops.gelu_and_mul(pt_output_1)
|
||||
elif activation == "swigluoai":
|
||||
pt_output_2 = ixfops.swigluoai_and_mul(pt_output_1)
|
||||
elif activation == "swiglustep":
|
||||
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
|
||||
output_dim = pt_output_1.shape[1]
|
||||
pt_output_2 = torch.empty(
|
||||
(num_tokens * top_k, output_dim // 2),
|
||||
device=pt_output_1.device,
|
||||
dtype=pt_output_1.dtype,
|
||||
)
|
||||
swiglustep_and_mul_triton(pt_output_2, pt_output_1)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
|
||||
# group gemm 2 + reorder
|
||||
pt_output_3 = ixfops.moe_w16a16_group_gemv(
|
||||
input=pt_output_2,
|
||||
weight=w2,
|
||||
output_dtype=hidden_states.dtype,
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
dst_to_src=sorted_token_ids,
|
||||
bias=w2_bias,
|
||||
format="TN",
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
)
|
||||
|
||||
else:
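|
||||
# Prefill, mixed, or expert-parallel batches take the grouped GEMM path.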
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
# expand + reorder
|
||||
hidden_states = ixfops.moe_expand_input(
|
||||
hidden_states=hidden_states,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_tokens=expand_tokens,
|
||||
topk=top_k,
|
||||
src_to_dst=src_to_dst,
|
||||
)
|
||||
# group gemm 1
|
||||
pt_output_1 = ixfops.moe_w16a16_group_gemm(
|
||||
input=hidden_states,
|
||||
weight=w1,
|
||||
output_dtype=hidden_states.dtype,
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
dst_to_src=None,
|
||||
bias=w1_bias,
|
||||
format="TN",
|
||||
)
|
||||
|
||||
# act
|
||||
if activation == "silu":
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
elif activation == "gelu":
|
||||
pt_output_2 = ixfops.gelu_and_mul(pt_output_1)
|
||||
elif activation == "swigluoai":
|
||||
pt_output_2 = ixfops.swigluoai_and_mul(pt_output_1)
|
||||
elif activation == "swiglustep":
|
||||
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
|
||||
output_dim = pt_output_1.shape[1]
|
||||
pt_output_2 = torch.empty(
|
||||
(num_tokens * top_k, output_dim//2),
|
||||
device=pt_output_1.device,
|
||||
dtype=pt_output_1.dtype,
|
||||
)
|
||||
swiglustep_and_mul_triton(pt_output_2, pt_output_1)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * top_k, hidden_size),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
# group gemm 2 + reorder
|
||||
pt_output_3 = ixfops.moe_w16a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=w2,
|
||||
output_dtype=hidden_states.dtype,
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="TN",
|
||||
bias=w2_bias,
|
||||
output=pt_output_3,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
reduce_mask = src_to_dst == -1
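|
||||
# src_to_dst is -1 for slots whose expert lives on another rank; mask them out of the weighted sum.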
|
||||
if output is not None:
|
||||
ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
output=output,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
# group gemm 2 + reorder
|
||||
pt_output_3 = ixfops.moe_w16a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=w2,
|
||||
output_dtype=hidden_states.dtype,
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
dst_to_src=sorted_token_ids,
|
||||
bias=w2_bias,
|
||||
format="TN",
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
)
|
||||
|
||||
if output is None:
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def _get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
@@ -1825,7 +1988,7 @@ def fused_experts_impl(
|
||||
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
|
||||
|
||||
# This needs separate memory since it's used concurrently with cache1
|
||||
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
|
||||
activation_out_dim = mk.FusedMoEExpertsModular.adjust_N_for_activation(
|
||||
N, activation_enum
|
||||
)
|
||||
intermediate_cache2 = torch.empty(
|
||||
@@ -1910,28 +2073,28 @@ def fused_experts_impl(
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
)
|
||||
|
||||
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
|
||||
# activates only a small fraction of total experts
|
||||
SPARSITY_FACTOR = 4
|
||||
# block quantized code path is not implemented yet.
|
||||
naive_block_assignment = (
|
||||
expert_map is None
|
||||
and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
|
||||
and not (
|
||||
(use_int8_w8a16 or use_int4_w4a16)
|
||||
and block_shape is not None
|
||||
and block_shape[1] > 0
|
||||
)
|
||||
)
|
||||
# # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
|
||||
# # activates only a small fraction of total experts
|
||||
# SPARSITY_FACTOR = 4
|
||||
# # block quantized code path is not implemented yet.
|
||||
# naive_block_assignment = (
|
||||
# expert_map is None
|
||||
# and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
|
||||
# and not (
|
||||
# (use_int8_w8a16 or use_int4_w4a16)
|
||||
# and block_shape is not None
|
||||
# and block_shape[1] > 0
|
||||
# )
|
||||
# )
|
||||
|
||||
# if not naive_block_assignment:
|
||||
# sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
# curr_topk_ids,
|
||||
# config["BLOCK_SIZE_M"],
|
||||
# global_num_experts,
|
||||
# expert_map,
|
||||
# ignore_invalid_experts=True,
|
||||
# )
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids,
|
||||
config["BLOCK_SIZE_M"],
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
ignore_invalid_experts=True,
|
||||
)
|
||||
# else:
|
||||
# max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
|
||||
# expert_ids = curr_topk_ids.view(-1)
|
||||
@@ -1941,14 +2104,6 @@ def fused_experts_impl(
|
||||
# num_tokens_post_padded.fill_(max_num_tokens_padded)
|
||||
# sorted_token_ids = None
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids,
|
||||
config["BLOCK_SIZE_M"],
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
ignore_invalid_experts=True,
|
||||
)
|
||||
|
||||
dispatch_fused_moe_kernel(
|
||||
qcurr_hidden_states,
|
||||
w1,
|
||||
@@ -2015,20 +2170,14 @@ def fused_experts_impl(
|
||||
B_bias=w2_bias,
|
||||
)
|
||||
|
||||
# ops.moe_sum(
|
||||
# intermediate_cache3.view(*intermediate_cache3.size()),
|
||||
# out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
# )
|
||||
torch.sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1,
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
|
||||
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class TritonExperts(mk.FusedMoEExpertsModular):
|
||||
"""Triton-based fused MoE expert implementation."""
|
||||
|
||||
def __init__(
|
||||
@@ -2091,8 +2240,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
# return not moe_parallel_config.use_fi_all2allv_kernels
|
||||
return True
|
||||
return not moe_parallel_config.use_fi_all2allv_kernels
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
@@ -2138,157 +2286,31 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
# Check constraints.
|
||||
if self.quant_config.use_int4_w4a16:
|
||||
assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch"
|
||||
else:
|
||||
assert hidden_states.size(-1) == w1.size(2), (
|
||||
f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
|
||||
)
|
||||
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert hidden_states.dim() == 2
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32,
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
]
|
||||
|
||||
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
|
||||
hidden_states, w1, w2, topk_ids
|
||||
)
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
config = try_get_optimal_moe_config(
|
||||
w1.size(),
|
||||
w2.size(),
|
||||
top_k_num,
|
||||
self.quant_config.config_name(hidden_states.dtype),
|
||||
num_tokens,
|
||||
block_shape=self.block_shape,
|
||||
)
|
||||
|
||||
if hidden_states.dtype == torch.bfloat16:
|
||||
compute_type = tl.bfloat16
|
||||
elif hidden_states.dtype == torch.float16:
|
||||
compute_type = tl.float16
|
||||
elif hidden_states.dtype == torch.float32:
|
||||
compute_type = tl.float32
|
||||
elif (
|
||||
hidden_states.dtype == torch.float8_e4m3fn
|
||||
or hidden_states.dtype == torch.float8_e4m3fnuz
|
||||
):
|
||||
compute_type = tl.bfloat16
|
||||
else:
|
||||
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
# Note that the output tensor might be in workspace1
|
||||
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
|
||||
cache2_dim = self.adjust_N_for_activation(N, activation)
|
||||
intermediate_cache2 = _resize_cache(
|
||||
workspace13, (num_tokens * top_k_num, cache2_dim)
|
||||
)
|
||||
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
|
||||
)
|
||||
|
||||
invoke_fused_moe_triton_kernel(
|
||||
hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
self.w1_scale,
|
||||
None, # topk_weights
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False, # mul_routed_weights
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.quant_config.use_int8_w8a8,
|
||||
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape,
|
||||
B_bias=self.w1_bias,
|
||||
)
|
||||
|
||||
self.activation(
|
||||
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
|
||||
a2q_scale: torch.Tensor | None = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
intermediate_cache2,
|
||||
a2_scale,
|
||||
self.quant_dtype,
|
||||
self.per_act_token_quant,
|
||||
self.block_shape,
|
||||
)
|
||||
|
||||
# invoke_fused_moe_triton_kernel(
|
||||
# qintermediate_cache2,
|
||||
# w2,
|
||||
# intermediate_cache3,
|
||||
# a2q_scale,
|
||||
# self.w2_scale,
|
||||
# topk_weights,
|
||||
# sorted_token_ids,
|
||||
# expert_ids,
|
||||
# num_tokens_post_padded,
|
||||
# not apply_router_weight_on_input,
|
||||
# 1,
|
||||
# config,
|
||||
# compute_type=compute_type,
|
||||
# use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||
# use_int8_w8a8=self.quant_config.use_int8_w8a8,
|
||||
# use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||
# use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||
# per_channel_quant=self.per_act_token_quant,
|
||||
# block_shape=self.block_shape,
|
||||
# B_bias=self.w2_bias,
|
||||
# )
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
self.w2_scale,
|
||||
self.w2_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
not apply_router_weight_on_input,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.quant_config.use_int8_w8a8,
|
||||
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape,
|
||||
B_bias=self.w2_bias,
|
||||
)
|
||||
|
||||
# separate function is required for MoE + LoRA
|
||||
self.moe_sum(intermediate_cache3, output)
|
||||
|
||||
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
ops.moe_sum(input, output)
|
||||
fused_experts_impl_opt(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
self.quant_config.use_fp8_w8a8,
|
||||
self.quant_config.use_int8_w8a8,
|
||||
self.quant_config.use_int8_w8a16,
|
||||
self.quant_config.use_int4_w4a16,
|
||||
self.quant_config.ocp_mx_scheme,
|
||||
self.quant_config.per_act_token_quant,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
self.quant_config.w1_scale,
|
||||
self.quant_config.w2_scale,
|
||||
self.quant_config.w1_zp,
|
||||
self.quant_config.w2_zp,
|
||||
self.quant_config.a1_scale,
|
||||
self.quant_config.a2_scale,
|
||||
self.quant_config.block_shape,
|
||||
self.quant_config.w1_bias,
|
||||
self.quant_config.w2_bias,
|
||||
output)
|
||||
|
||||
|
||||
class TritonWNA16Experts(TritonExperts):
|
||||
|
||||
@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEExpertsModular,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
@@ -27,19 +27,21 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
super().__init__()
|
||||
self.moe: FusedMoEConfig = moe
|
||||
self.moe_quant_config: FusedMoEQuantConfig | None = None
|
||||
self.moe_mk: mk.FusedMoEModularKernel | None = None
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
@property
|
||||
def supports_internal_mk(self) -> bool:
|
||||
# NOTE(rob): temporary attribute to indicate support for
|
||||
# completed migration to the new internal MK interface.
|
||||
return self.moe_mk is not None
|
||||
return self.moe_kernel is not None
|
||||
|
||||
@property
|
||||
def mk_owns_shared_expert(self) -> bool:
|
||||
# NOTE(rob): temporary attribute to indicate support for
|
||||
# completed migration to the new internal MK interface.
|
||||
return self.moe_mk is not None and self.moe_mk.shared_experts is not None
|
||||
return (
|
||||
self.moe_kernel is not None and self.moe_kernel.shared_experts is not None
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
@@ -66,35 +68,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
) -> FusedMoEPrepareAndFinalizeModular | None:
|
||||
from .all2all_utils import maybe_make_prepare_finalize
|
||||
|
||||
return maybe_make_prepare_finalize(
|
||||
pf = maybe_make_prepare_finalize(
|
||||
self.moe, self.moe_quant_config, routing_tables
|
||||
)
|
||||
assert pf is None or isinstance(pf, FusedMoEPrepareAndFinalizeModular)
|
||||
return pf
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
) -> FusedMoEExpertsModular:
|
||||
# based on the all2all implementation, select the appropriate
|
||||
# gemm implementation
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} must select appropriate gemm "
|
||||
"implementation based on the prepare_finalize"
|
||||
)
|
||||
|
||||
def prepare_dp_allgather_tensor(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
|
||||
raise NotImplementedError(
|
||||
"Method 'prepare_dp_allgather_tensor' is not implemented in "
|
||||
f"{self.__class__.__name__}."
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -105,8 +97,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@property
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
if self.moe_mk is not None:
|
||||
return self.moe_mk.prepare_finalize.topk_indices_dtype()
|
||||
if self.moe_kernel is not None:
|
||||
return self.moe_kernel.prepare_finalize.topk_indices_dtype()
|
||||
return None
|
||||
|
||||
@property
|
||||
@@ -119,7 +111,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return False
|
||||
if self.moe_kernel is None:
|
||||
if hasattr(self, "experts_cls"):
|
||||
return self.experts_cls.is_monolithic()
|
||||
else:
|
||||
return False
|
||||
return self.moe_kernel.is_monolithic
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEKernel,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
# --8<-- [end:modular_fused_moe]
|
||||
|
||||
def __init__(
|
||||
self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
|
||||
self, old_quant_method: FusedMoEMethodBase, moe_kernel: FusedMoEKernel
|
||||
):
|
||||
super().__init__(old_quant_method.moe)
|
||||
self.moe_quant_config = old_quant_method.moe_quant_config
|
||||
self.moe_mk = experts
|
||||
self.moe_kernel = moe_kernel
|
||||
self.disable_expert_map = getattr(
|
||||
old_quant_method,
|
||||
"disable_expert_map",
|
||||
not self.moe_mk.supports_expert_map(),
|
||||
not self.moe_kernel.supports_expert_map(),
|
||||
)
|
||||
self.old_quant_method = old_quant_method
|
||||
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
|
||||
@@ -43,13 +43,13 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
def make(
|
||||
moe_layer: torch.nn.Module,
|
||||
old_quant_method: FusedMoEMethodBase,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
shared_experts: torch.nn.Module | None,
|
||||
inplace: bool = False,
|
||||
) -> "FusedMoEModularMethod":
|
||||
return FusedMoEModularMethod(
|
||||
old_quant_method,
|
||||
FusedMoEModularKernel(
|
||||
FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
|
||||
shared_experts,
|
||||
@@ -90,8 +90,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -178,7 +179,40 @@ def triton_kernel_moe_forward(
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
unpadded_N_w1=None,
|
||||
unpadded_K_w1=None,
|
||||
unpadded_N_w2=None,
|
||||
unpadded_K_w2=None,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
quant_config is not None
|
||||
and quant_config.use_mxfp4_w4a8
|
||||
and rocm_aiter_ops.is_enabled()
|
||||
):
|
||||
from aiter.ops.triton.moe_routing.routing import routing as aiter_routing
|
||||
|
||||
routing_data, gather_idx, scatter_idx = aiter_routing(
|
||||
gating_output, topk, sm_first=not renormalize
|
||||
)
|
||||
return triton_kernel_fused_mxfp4_w4a8_experts(
|
||||
None,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
routing_data,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
activation=activation.value,
|
||||
quant_config=quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
unpadded_N_w1=unpadded_N_w1,
|
||||
unpadded_K_w1=unpadded_K_w1,
|
||||
unpadded_N_w2=unpadded_N_w2,
|
||||
unpadded_K_w2=unpadded_K_w2,
|
||||
)
|
||||
|
||||
if expert_map is not None:
|
||||
# With expert parallelism, legacy_routing produces routing data
|
||||
# using global expert IDs which don't correspond to local weight
|
||||
@@ -210,6 +244,9 @@ def triton_kernel_moe_forward(
|
||||
effective_global_num_experts = global_num_experts
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
effective_quant_config = (
|
||||
quant_config if quant_config is not None else FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
)
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
output,
|
||||
@@ -221,7 +258,7 @@ def triton_kernel_moe_forward(
|
||||
scatter_idx,
|
||||
topk=topk,
|
||||
activation=activation,
|
||||
quant_config=quant_config,
|
||||
quant_config=effective_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=effective_global_num_experts,
|
||||
expert_map=effective_expert_map,
|
||||
@@ -252,8 +289,7 @@ def triton_kernel_fused_experts(
|
||||
assert activation == MoEActivation.SWIGLUOAI, (
|
||||
"Only SWIGLUOAI activation is supported"
|
||||
)
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
assert quant_config is not None
|
||||
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
@@ -330,6 +366,98 @@ def triton_kernel_fused_experts(
|
||||
return output_tensor
|
||||
|
||||
|
||||
# This is a triton implementation of the fused_experts function
|
||||
def triton_kernel_fused_mxfp4_w4a8_experts(
|
||||
output_tensor: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1, # Tensor or triton_kernels.Tensor
|
||||
w2, # Tensor or triton_kernels.Tensor
|
||||
routing_data, # RoutingData
|
||||
gather_indx, # GatherIndx
|
||||
scatter_indx, # ScatterIndx
|
||||
activation: str = "silu",
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
swiglu_alpha: float = 1.702,
|
||||
swiglu_limit: float = 7.0,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
a1q_scale: torch.Tensor | None = None,
|
||||
unpadded_N_w1=None,
|
||||
unpadded_K_w1=None,
|
||||
unpadded_N_w2=None,
|
||||
unpadded_K_w2=None,
|
||||
) -> torch.Tensor:
|
||||
assert quant_config is not None
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
|
||||
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
|
||||
|
||||
# Shape check, only check non-mxfp4
|
||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||
assert w2.shape[-1] == w1.shape[1]
|
||||
|
||||
E, _, N = w1.shape
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
gammas = routing_data.gate_scal if routing_data else None
|
||||
|
||||
from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4
|
||||
from aiter.ops.triton.quant_moe import downcast_to_static_fp8
|
||||
|
||||
assert quant_config.w1_precision is not None, (
|
||||
"w1_precision in quant config can't be None"
|
||||
)
|
||||
assert quant_config.w2_precision is not None, (
|
||||
"w2_precision in quant config can't be None"
|
||||
)
|
||||
|
||||
hidden_states = downcast_to_static_fp8(
|
||||
hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale
|
||||
)
|
||||
|
||||
intermediate_cache1 = moe_gemm_a8w4(
|
||||
hidden_states,
|
||||
w1.storage.data,
|
||||
None,
|
||||
quant_config.w1_precision.weight_scale.storage.data,
|
||||
quant_config.w1_precision.flex_ctx.lhs_data.scale,
|
||||
quant_config.w2_precision.flex_ctx.lhs_data.scale,
|
||||
quant_config.w1_bias,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
gammas=gammas if apply_router_weight_on_input else None,
|
||||
swizzle_mx_scale="CDNA4_SCALE",
|
||||
out_dtype=torch.float8_e4m3fn,
|
||||
apply_swiglu=True,
|
||||
alpha=swiglu_alpha,
|
||||
limit=swiglu_limit,
|
||||
unpadded_N=unpadded_N_w1,
|
||||
unpadded_K=unpadded_K_w1,
|
||||
)
|
||||
|
||||
intermediate_cache3 = moe_gemm_a8w4(
|
||||
intermediate_cache1,
|
||||
w2.storage.data,
|
||||
None,
|
||||
quant_config.w2_precision.weight_scale.storage.data,
|
||||
quant_config.w2_precision.flex_ctx.lhs_data.scale,
|
||||
None,
|
||||
quant_config.w2_bias,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
gammas=None if apply_router_weight_on_input else gammas,
|
||||
swizzle_mx_scale="CDNA4_SCALE",
|
||||
unpadded_N=unpadded_N_w2,
|
||||
unpadded_K=unpadded_K_w2,
|
||||
)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
def make_routing_data(
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
@@ -383,7 +511,7 @@ def make_routing_data(
|
||||
return routing_data, gather_indx, scatter_indx
|
||||
|
||||
|
||||
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
raise NotImplementedError(
|
||||
@@ -520,6 +648,9 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
if self.quant_config is None:
|
||||
self.quant_config: FusedMoEQuantConfig = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ from collections.abc import Callable, Iterable
|
||||
from enum import Enum
|
||||
from typing import Literal, cast, get_args, overload
|
||||
|
||||
import ast, re
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import UninitializedParameter
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -54,10 +54,14 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.model_executor.layers.utils import (
|
||||
parse_opt_exclude_layers,
|
||||
weight_quant_l1,
|
||||
weight_quant_l2,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FusedMoeWeightScaleSupported(Enum):
|
||||
TENSOR = "tensor"
|
||||
CHANNEL = "channel"
|
||||
@@ -333,6 +337,7 @@ class FusedMoE(CustomOp):
|
||||
gate: torch.nn.Module | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
routed_input_transform: torch.nn.Module | None = None,
|
||||
fused_shared_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -483,6 +488,8 @@ class FusedMoE(CustomOp):
|
||||
(expert_mask == 0) | (expert_mask == 1)
|
||||
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_experts = num_experts
|
||||
assert intermediate_size % self.tp_size == 0
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
self.reduce_results = reduce_results
|
||||
@@ -526,16 +533,18 @@ class FusedMoE(CustomOp):
|
||||
|
||||
# Round up hidden size before creating moe_config.
|
||||
# This way moe_config is created with the correct hidden_size from the start.
|
||||
unpadded_hidden_size = hidden_size
|
||||
self.model_type = (
|
||||
self.vllm_config.model_config.hf_config.model_type
|
||||
if self.vllm_config.model_config is not None
|
||||
else None
|
||||
)
|
||||
hidden_size = maybe_roundup_hidden_size(
|
||||
hidden_size=hidden_size,
|
||||
act_dtype=moe_in_dtype,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
is_lora_enabled=vllm_config.lora_config is not None,
|
||||
model_type=(
|
||||
self.vllm_config.model_config.hf_config.model_type
|
||||
if self.vllm_config.model_config is not None
|
||||
else None
|
||||
),
|
||||
model_type=self.model_type,
|
||||
is_mxfp4_quant=(
|
||||
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
|
||||
),
|
||||
@@ -581,14 +590,27 @@ class FusedMoE(CustomOp):
|
||||
"""
|
||||
quant_method = None
|
||||
if self.quant_config is not None:
|
||||
self.opt_level = 0
|
||||
quant_method = self.quant_config.get_quant_method(self, prefix)
|
||||
if quant_method is None:
|
||||
quant_method = UnquantizedFusedMoEMethod(self.moe_config)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||
CompressedTensorsL1OptMoEMethod, CompressedTensorsL2OptMoEMethod)
|
||||
if self.opt_level == 1:
|
||||
quant_method = CompressedTensorsL1OptMoEMethod(self.moe_config)
|
||||
elif self.opt_level == 2:
|
||||
quant_method = CompressedTensorsL2OptMoEMethod(self.moe_config)
|
||||
else:
|
||||
quant_method = UnquantizedFusedMoEMethod(self.moe_config)
|
||||
assert isinstance(quant_method, FusedMoEMethodBase)
|
||||
return quant_method
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
self.opt_level = envs.VLLM_MOE_OPT_LEVEL
|
||||
if parse_opt_exclude_layers(envs.VLLM_OPT_EXCLUDE_LAYERS, prefix):
|
||||
self.opt_flag = False
|
||||
logger.info(f"Excluding layer {prefix} from optimization")
|
||||
|
||||
self.quant_method: FusedMoEMethodBase = _get_quant_method()
|
||||
|
||||
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
|
||||
@@ -611,6 +633,7 @@ class FusedMoE(CustomOp):
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"unpadded_hidden_size": unpadded_hidden_size,
|
||||
"intermediate_size_per_partition": self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
@@ -625,6 +648,7 @@ class FusedMoE(CustomOp):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.base_quant_method = self.quant_method
|
||||
|
||||
# Disable shared expert overlap if:
|
||||
# - we are using eplb with non-default backend, because of correctness issues
|
||||
@@ -638,7 +662,10 @@ class FusedMoE(CustomOp):
|
||||
)
|
||||
and self._shared_experts is not None
|
||||
)
|
||||
|
||||
if fused_shared_output:
|
||||
assert not self.use_ep, "Fused shared output is only supported when EP is disabled."
|
||||
assert shared_experts is not None, "Shared experts must be provided when fused_shared_output is True."
|
||||
self.fused_shared_output = fused_shared_output
|
||||
self.runner = self._init_runner()
|
||||
|
||||
def _init_runner(self):
|
||||
@@ -655,6 +682,7 @@ class FusedMoE(CustomOp):
|
||||
quant_method=self.quant_method,
|
||||
reduce_results=self.reduce_results,
|
||||
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
|
||||
fused_shared_output=self.fused_shared_output,
|
||||
)
|
||||
|
||||
# TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py
|
||||
@@ -681,7 +709,7 @@ class FusedMoE(CustomOp):
|
||||
# routing_tables only needed for round-robin expert placement with
|
||||
# DeepEP all2all backend.
|
||||
routing_tables = self._maybe_init_expert_routing_tables()
|
||||
prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
|
||||
prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize(
|
||||
routing_tables=routing_tables
|
||||
)
|
||||
if prepare_finalize is not None:
|
||||
@@ -691,7 +719,7 @@ class FusedMoE(CustomOp):
|
||||
self._replace_quant_method(
|
||||
FusedMoEModularMethod.make(
|
||||
self,
|
||||
self.quant_method,
|
||||
self.base_quant_method,
|
||||
prepare_finalize,
|
||||
self.shared_experts,
|
||||
inplace=not self.moe_config.disable_inplace,
|
||||
@@ -959,11 +987,7 @@ class FusedMoE(CustomOp):
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
try:
|
||||
expert_data.copy_(loaded_weight)
|
||||
except Exception as e:
|
||||
print(expert_data.shape, expert_data.dtype, loaded_weight.shape, loaded_weight.dtype)
|
||||
raise e
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(
|
||||
self,
|
||||
@@ -976,7 +1000,7 @@ class FusedMoE(CustomOp):
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
# Narrow parameter and load.
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
shard_size = loaded_weight.shape[shard_dim] // self.tp_size
|
||||
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
|
||||
# and we're not loading the full weight
|
||||
if not load_full and loaded_weight.ndim > 0:
|
||||
@@ -984,7 +1008,55 @@ class FusedMoE(CustomOp):
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
expert_data.copy_(loaded_weight)
|
||||
expert_data.narrow(shard_dim, 0, shard_size).copy_(loaded_weight)
|
||||
|
||||
def _load_model_opt_weight_or_group_weight_scale(self,
|
||||
shard_dim: int,
|
||||
shard_dim_scale: int,
|
||||
expert_data: torch.Tensor,
|
||||
scale_data: torch.Tensor,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
opt_level: int,
|
||||
load_full_w2: bool = False):
|
||||
"""
|
||||
Load grouped weight scales for group quantization or model weights
|
||||
:param shard_dim: dimension to shard
|
||||
:param expert_data: parameter for a particular expert
|
||||
:param shard_id: either w1, w2, or w3
|
||||
:param loaded_weight: checkpoint weight to load into the param
|
||||
:param tp_rank: tensor parallel rank
|
||||
:param load_full_w2: whether or not the w2 loaded should be sharded.
|
||||
"""
|
||||
|
||||
assert opt_level in [1, 2]
|
||||
if opt_level == 1:
|
||||
weight, scale = weight_quant_l1(loaded_weight)
|
||||
else:
|
||||
weight, scale = weight_quant_l2(loaded_weight)
|
||||
scale = scale.view(1, -1)
|
||||
|
||||
if shard_id == "w2":
|
||||
# In the case where we have actorder/g_idx, we do not partition the
|
||||
# w2 scales, as indicated by `load_full` argument, for all tp cases
|
||||
self._load_w2(shard_dim=shard_dim,
|
||||
loaded_weight=weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank,
|
||||
load_full=load_full_w2)
|
||||
scale_data.copy_(scale)
|
||||
elif shard_id in ("w1", "w3"):
|
||||
self._load_w13(shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=tp_rank)
|
||||
self._load_w13(shard_id=shard_id,
|
||||
shard_dim=shard_dim_scale,
|
||||
loaded_weight=scale,
|
||||
expert_data=scale_data,
|
||||
tp_rank=tp_rank)
|
||||
|
||||
def _load_single_value(
|
||||
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
|
||||
@@ -1147,7 +1219,6 @@ class FusedMoE(CustomOp):
|
||||
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||
if is_transposed:
|
||||
shard_dim = int(not shard_dim)
|
||||
|
||||
shard_dim_force = getattr(param, "shard_dim", None)
|
||||
shard_dim = shard_dim_force if shard_dim_force is not None else shard_dim
|
||||
|
||||
@@ -1309,13 +1380,28 @@ class FusedMoE(CustomOp):
|
||||
|
||||
# Case model weights
|
||||
if "weight" in weight_name:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=self.tp_rank,
|
||||
)
|
||||
if self.opt_level != 0:
|
||||
scale_name = weight_name.split('.')[-1] + "_scale"
|
||||
params_dict = dict(self.named_parameters())
|
||||
scale_param = params_dict[scale_name]
|
||||
shard_dim_scale = getattr(scale_param, "shard_dim", None)
|
||||
scale_expert_data = scale_param.data if full_load else scale_param.data[expert_id]
|
||||
self._load_model_opt_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
shard_dim_scale=shard_dim_scale,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
scale_data=scale_expert_data,
|
||||
opt_level=self.opt_level,
|
||||
tp_rank=self.tp_rank)
|
||||
else:
|
||||
self._load_model_weight_or_group_weight_scale(
|
||||
shard_id=shard_id,
|
||||
shard_dim=shard_dim,
|
||||
loaded_weight=loaded_weight,
|
||||
expert_data=expert_data,
|
||||
tp_rank=self.tp_rank)
|
||||
return True if return_success else None
|
||||
|
||||
return False if return_success else None
|
||||
|
||||
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
@@ -56,25 +57,25 @@ logger = init_logger(__name__)
|
||||
# MoE kernel implementations.
|
||||
#
|
||||
# The following main classes are defined:
|
||||
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
|
||||
# * FusedMoEPrepareAndFinalizeModular - an abstract base class for preparation of MoE
|
||||
# inputs (e.g. quantization, distribution) and finalization of MoE outputs.
|
||||
# The prepare method must take care of any needed quantization and the
|
||||
# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
|
||||
# finalize method, informed by the FusedMoEExpertsModular method,
|
||||
# may apply weights and/or do the final reduction of the output.
|
||||
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
|
||||
# * FusedMoEExpertsModular - an abstract base class for the main fused
|
||||
# MoE operation, i.e. matmul + act_mul + optionally quant + matmul.
|
||||
# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
|
||||
# Some FusedMoEExpertsModular implementations may choose to do
|
||||
# the weight application and/or reduction. The class communicates this
|
||||
# to [Finalize] via a TopKWeightAndReduce object.
|
||||
# * FusedMoEModularKernel - an interface class that combines a
|
||||
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
|
||||
# FusedMoEPrepareAndFinalizeModular and a FusedMoEExpertsModular to
|
||||
# provide the standard fused MoE kernel interface.
|
||||
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
|
||||
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
|
||||
# by the FusedMoEExpertsModular implementation that is passed
|
||||
# on to [Finalize].
|
||||
#
|
||||
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
|
||||
# class `FusedMoEPrepareAndFinalize` since they could use collective
|
||||
# class `FusedMoEPrepareAndFinalizeModular` since they could use collective
|
||||
# communication mechanisms that need to be consistent.
|
||||
#
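#
# A minimal sketch of how these pieces compose (hedged: argument lists are
# abbreviated and the concrete implementations are placeholders; see the
# class definitions below for the exact signatures):
#
#   pf: FusedMoEPrepareAndFinalizeModular = ...
#   experts: FusedMoEExpertsModular = ...
#   a1q, a1q_scale, *rest = pf.prepare(a1, topk_weights, topk_ids, ...)
#   fused_out = experts.apply(a1q, w1, w2, topk_ids=topk_ids, ...)
#   out = pf.finalize(fused_out, topk_weights, topk_ids, ...)
#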
|
||||
|
||||
@@ -155,25 +156,96 @@ PrepareResultType = tuple[
|
||||
torch.Tensor | None,
|
||||
]
|
||||
|
||||
#
|
||||
# PrepareMonolithicResultType is a tuple of:
|
||||
# - quantized + dispatched a.
|
||||
# - quantized + dispatched a1_scales.
|
||||
# - dispatched router logits.
|
||||
#
|
||||
# See `prepare_monolithic` method below.
|
||||
#
|
||||
PrepareMonolithicResultType = tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor,
|
||||
]
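#
# For example, a monolithic prepare() would typically be consumed as
# (illustrative variable names):
#   a1q, a1q_scale, dispatched_logits = pf.prepare(a1, router_logits, quant_config)
#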
|
||||
|
||||
ReceiverType = Callable[[], PrepareResultType]
|
||||
|
||||
################################################################################
|
||||
# Prepare/Finalize
|
||||
################################################################################
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
described above.
|
||||
|
||||
There are two variants of this class:
|
||||
* FusedMoEPrepareAndFinalizeModular - this operates on topk ids and weights
|
||||
* FusedMoEPrepareAndFinalizeMonolithic - this operates on router_logits
|
||||
"""
|
||||
|
||||
def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
|
||||
def post_init_setup(self, fused_experts: "FusedMoEExperts"):
|
||||
"""
|
||||
Initialize FusedMoEPrepareAndFinalize settings that depend on
|
||||
FusedMoEPermuteExpertsUnpermute experts object.
|
||||
The FusedMoEPrepareAndFinalize implementations that have such
|
||||
Initialize FusedMoEPrepareAndFinalizeModular settings that depend on
|
||||
FusedMoEExpertsModular experts object.
|
||||
The FusedMoEPrepareAndFinalizeModular implementations that have such
|
||||
dependencies may choose to override this function.
|
||||
"""
|
||||
return
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def activation_format(self) -> FusedMoEActivationFormat:
|
||||
"""
|
||||
A property indicating the output format of the activations for the
|
||||
'prepare' method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
"""
|
||||
The PrepareFinalize All2All implementations generally constrain the
|
||||
dtype of the topk_ids they support. This function returns the
|
||||
required topk indices dtype so it can be respected.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
"""
|
||||
Some PrepareFinalize All2All implementations are batched. Meaning,
|
||||
they can process only a set of tokens at a time. This
|
||||
function returns the batch size, i.e. the maximum number of tokens
|
||||
the implementation can process at a time.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def num_dispatchers(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of finalize is reduced across all
|
||||
ranks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
described above for the Modular case.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self,
|
||||
@@ -198,7 +270,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
activations, before quantization + dispatching.
|
||||
- quant_config: Quantization info provided by the fused experts.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEPermuteExpertsUnpermute
|
||||
defer input quantization to the FusedMoEExpertsModular
|
||||
in cases where the compute kernel expects unquantized inputs
|
||||
|
||||
Returns a tuple of:
|
||||
@@ -245,7 +317,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEPermuteExpertsUnpermute
|
||||
defer input quantization to the FusedMoEExpertsModular
|
||||
in cases where the compute kernel expects unquantized inputs
|
||||
|
||||
Returns a callback or a hook callback pair that when invoked waits for
|
||||
@@ -338,56 +410,58 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
|
||||
class FusedMoEPrepareAndFinalizeMonolithic(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
described above for the monolithic case.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def activation_format(self) -> FusedMoEActivationFormat:
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> PrepareMonolithicResultType:
|
||||
"""
|
||||
A property indicating the output format of the activations for the
|
||||
'prepare' method.
|
||||
Optional method for subclasses compatible with monolithic
|
||||
FusedMoEExpertsMonolithic kernels.
|
||||
|
||||
Perform any quantization and/or dispatching needed for this kernel.
|
||||
- a1: The (unquantized) input to the MoE layer.
|
||||
- quant_config: Quantization info provided by the fused experts.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEExpertsModular
|
||||
|
||||
Returns a tuple of:
|
||||
- quantized + dispatched a.
|
||||
- Optional quantized + dispatched a1_scales.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
def finalize(self, fused_expert_output: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
The PrepareFinalize All2All implementations generally constrain the
|
||||
dtype of the topk_ids they support. This function returns the
|
||||
required topk indices dtype so it can be respected.
|
||||
Return None if there are no such restrictions.
|
||||
Optional method for subclasses compatible with monolithic
|
||||
FusedMoEExpertsMonolithic kernels.
|
||||
|
||||
Perform any combine, apply the topk weights, and reduce the
|
||||
fused experts output.
|
||||
- fused_expert_output: The unweighted, unreduced output of the fused
|
||||
experts, it will have (M, topk, K) shape.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
"""
|
||||
Some PrepareFinalize All2All implementations are batched. Meaning,
|
||||
they can process only a set of tokens at a time. This
|
||||
function returns the batch size, i.e. the maximum number of tokens
|
||||
the implementation can process at a time.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def num_dispatchers(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of finalize is reduced across all
|
||||
ranks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
################################################################################
|
||||
# Experts
|
||||
################################################################################
|
||||
|
||||
|
||||
# TODO: add supported activations method (return string)
|
||||
class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
above.
|
||||
"""
|
||||
|
||||
class FusedMoEExperts(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
@@ -419,6 +493,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
@staticmethod
|
||||
def is_monolithic() -> bool:
|
||||
raise NotImplementedError("Implemented by subclasses.")
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
"""
|
||||
@@ -439,49 +517,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def moe_problem_size(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> tuple[int, int, int, int, int]:
|
||||
"""
|
||||
Extract the MoE problem size from the given tensor arguments:
|
||||
- a: The hidden states, input to the MoE layer.
|
||||
- w1: The first set of expert weights.
|
||||
- w2: The second set of expert weights.
|
||||
- topk_ids: The topk ids.
|
||||
|
||||
Note: extracting the problem shape from the weight and activation
|
||||
tensors is not obvious. It needs to be done this way specifically
|
||||
due to subtle issues with particular kernels, e.g. the int4 kernels
|
||||
divide the trailing dimension by two, so it's not "correct" to
|
||||
extract N or K from the trailing dimension of w1 or w2. Similarly,
|
||||
some kernels transpose the weights, so this needs to be kept in mind.
|
||||
|
||||
Note: This implementation covers most cases. However, if experts
|
||||
require a specialized implementation, like MarlinExperts, they are free
|
||||
to override this function.
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, N, _ = w1.size()
|
||||
K = a1.size(-1)
|
||||
|
||||
if a1.dim() == 2:
|
||||
# Make sure we are using the correct a1 (pre-permute).
|
||||
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||
M = a1.size(0)
|
||||
else:
|
||||
assert a1.dim() == 3
|
||||
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
||||
M = a1.size(1) # This is max_num_tokens
|
||||
|
||||
assert topk_ids.dim() == 2
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
return E, M, N, K, topk
|
||||
|
||||
#
|
||||
# Various helpers for registering support for various features.
|
||||
# Used by the oracle to select a particular kernel for a deployment.
|
||||
@@ -489,7 +524,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
|
||||
@staticmethod
|
||||
def is_supported_config(
|
||||
cls: type["FusedMoEPermuteExpertsUnpermute"],
|
||||
cls: type["FusedMoEExperts"],
|
||||
moe_config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
@@ -512,6 +547,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
return False, _make_reason(
|
||||
f"parallel config {moe_config.moe_parallel_config}"
|
||||
)
|
||||
elif not cls._supports_routing_method(
|
||||
moe_config.routing_method, weight_key, activation_key
|
||||
):
|
||||
return False, _make_reason(f"routing method {moe_config.routing_method}")
|
||||
elif not cls._supports_router_logits_dtype(
|
||||
moe_config.router_logits_dtype,
|
||||
moe_config.routing_method,
|
||||
):
|
||||
return False, _make_reason(
|
||||
f"router logits dtype {moe_config.router_logits_dtype}"
|
||||
)
|
||||
elif not cls._supports_shape(moe_config.hidden_dim):
|
||||
return False, _make_reason(
|
||||
f"{moe_config.hidden_dim} hidden dim is not supported"
|
||||
)
|
||||
elif activation_format != cls.activation_format():
|
||||
return False, _make_reason(f"{activation_format.value} activation format")
|
||||
return True, None
|
||||
@@ -554,10 +604,48 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
@abstractmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""
|
||||
Whether the kernel supports deployment in expert parallel.
|
||||
Whether the kernel supports deployment in particular parallel config.
|
||||
|
||||
Can be overridden if a kernel does not support EP, SP, or some other
|
||||
configuration.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether the kernel supports a routing method (e.g. GroupedTopK).
|
||||
|
||||
Can be overridden by monolithic kernels that execute the router
|
||||
in addition to the experts if certain routers are not supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether a kernel supports a particular dtype for router logits input.
|
||||
|
||||
Can be overridden by monolithic kernels that execute the router
|
||||
in addition to the experts if certain dtypes are not supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_shape(hidden_dim: int) -> bool:
|
||||
"""
|
||||
Whether a kernel supports a particular shape. Can be overridden if a kernel
|
||||
has specific shape requirements.
|
||||
"""
|
||||
return True
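
# As an illustration, a hypothetical kernel that only runs without expert
# parallelism and only for hidden sizes divisible by 256 could override the
# hooks above as follows (class name is a placeholder):
#
#   class MyExperts(FusedMoEExpertsModular):
#       @staticmethod
#       def _supports_parallel_config(moe_parallel_config) -> bool:
#           return not moe_parallel_config.use_ep
#
#       @staticmethod
#       def _supports_shape(hidden_dim: int) -> bool:
#           return hidden_dim % 256 == 0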
|
||||
|
||||
#
|
||||
# Various helpers for accessing quantization parameters from the
|
||||
# quant_config.
|
||||
@@ -654,6 +742,65 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
return False
|
||||
|
||||
def enable_chunking(self):
|
||||
return (
|
||||
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
|
||||
)
|
||||
|
||||
|
||||
class FusedMoEExpertsModular(FusedMoEExperts):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
above.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def is_monolithic() -> bool:
|
||||
return False
|
||||
|
||||
def moe_problem_size(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> tuple[int, int, int, int, int]:
|
||||
"""
|
||||
Extract the MoE problem size from the given tensor arguments:
|
||||
- a: The hidden states, input to the MoE layer.
|
||||
- w1: The first set of expert weights.
|
||||
- w2: The second set of expert weights.
|
||||
- topk_ids: The topk ids.
|
||||
|
||||
Note: extracting the problem shape from the weight and activation
|
||||
tensors is not obvious. It needs to be done this way specifically
|
||||
due to subtle issues with particular kernels, e.g. the int4 kernels
|
||||
divide the trailing dimension by two, so it's not "correct" to
|
||||
extract N or K from the trailing dimension of w1 or w2. Similarly,
|
||||
some kernels transpose the weights, so this needs to be kept in mind.
|
||||
|
||||
Note: This implementation covers most cases. However, if experts
|
||||
require a specialized implementation, like MarlinExperts, they are free
|
||||
to override this function.
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, N, _ = w1.size()
|
||||
K = a1.size(-1)
|
||||
|
||||
if a1.dim() == 2:
|
||||
# Make sure we are using the correct a1 (pre-permute).
|
||||
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||
M = a1.size(0)
|
||||
else:
|
||||
assert a1.dim() == 3
|
||||
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
||||
M = a1.size(1) # This is max_num_tokens
|
||||
|
||||
assert topk_ids.dim() == 2
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
return E, M, N, K, topk
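
# Worked example (illustrative shapes): with w1 of shape (E=8, N=4096, 2048),
# w2 of shape (8, 2048, 2048), a1 of shape (M=16, K=2048) and topk_ids of
# shape (16, 2), this returns (E, M, N, K, topk) == (8, 16, 4096, 2048, 2).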
|
||||
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
"""
|
||||
Workspace type: The dtype to use for the workspace tensors.
|
||||
@@ -726,11 +873,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
) -> None:
|
||||
apply_moe_activation(activation, output, input)
|
||||
|
||||
def enable_chunking(self):
|
||||
return (
|
||||
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -791,6 +934,67 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FusedMoEExpertsMonolithic(FusedMoEExperts):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
above, but with the monolithic interface (accepts router logits
|
||||
rather than topk ids and weights).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether the kernel supports a routing method (e.g. GroupedTopK).
|
||||
|
||||
Monolithic kernels should explicitly opt-in to support.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether the kernel supports a dtype for router logits.
|
||||
|
||||
Monolithic kernels should explicitly opt in to support.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def is_monolithic() -> bool:
|
||||
return True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Same as FusedMoEExpertsModular.apply(), except it uses router_logits as opposed
|
||||
to the topk_ids and topk_weights. This is useful for kernels
|
||||
with fused router and fused_experts (e.g. FLASHINFER_TRTLLM).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _slice_scales(
|
||||
scales: torch.Tensor | None, start: int, end: int
|
||||
) -> torch.Tensor | None:
|
||||
@@ -802,75 +1006,32 @@ def _slice_scales(
|
||||
return None
|
||||
|
||||
|
||||
################################################################################
|
||||
# Kernel
|
||||
################################################################################
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
This class combines a FusedMoEPrepareAndFinalize instance and
|
||||
a FusedMoEPermuteExpertsUnpermute to provide an interface that
|
||||
is compatible with the `fused_experts` function in fused_moe.py.
|
||||
|
||||
It takes care of managing any required scratch space.
|
||||
|
||||
Note: Instances of this class should only be used for a single model
|
||||
layer due to any layer specific state that may be used by the component
|
||||
objects.
|
||||
"""
|
||||
|
||||
class FusedMoEKernelModularImpl:
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
fused_experts: FusedMoEExpertsModular,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
||||
inplace: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.shared_experts = shared_experts
|
||||
self.moe_parallel_config = moe_parallel_config
|
||||
self.inplace = inplace
|
||||
|
||||
# prefer an explicit FusedMoEParallelConfig when available (from
|
||||
# FusedMoE layers / tests).
|
||||
# if not provided, assume this kernel is
|
||||
# running in a non-DP+EP context
|
||||
self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
|
||||
self.is_dp_ep = (
|
||||
moe_parallel_config is not None
|
||||
and moe_parallel_config.dp_size > 1
|
||||
and moe_parallel_config.use_ep
|
||||
)
|
||||
|
||||
self._post_init_setup()
|
||||
assert (
|
||||
prepare_finalize.activation_format == fused_experts.activation_format()
|
||||
), (
|
||||
f"{prepare_finalize.__class__.__name__}."
|
||||
f"{prepare_finalize.activation_format} == "
|
||||
f"{fused_experts.__class__.__name__}."
|
||||
f"{fused_experts.activation_format()}"
|
||||
)
|
||||
|
||||
def _post_init_setup(self):
|
||||
"""
|
||||
Resolve any leftover setup dependencies between self.prepare_finalize
|
||||
and self.fused_experts here.
|
||||
"""
|
||||
self.prepare_finalize.post_init_setup(self.fused_experts)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps.
|
||||
"""
|
||||
return self.fused_experts.supports_expert_map()
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of fused MoE kernel
|
||||
is reduced across all ranks.
|
||||
"""
|
||||
return self.prepare_finalize.output_is_reduced()
|
||||
|
||||
def _chunk_info(self, M: int) -> tuple[int, int]:
|
||||
"""
|
||||
Compute number of chunks and chunk size for given M.
|
||||
@@ -919,7 +1080,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
|
||||
|
||||
# Force worst-case allocation in profiling run for
|
||||
# "mk.FusedMoEModularKernel.Standard" formats where this is only bounded
|
||||
# "mk.FusedMoEKernel.Standard" formats where this is only bounded
|
||||
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
|
||||
# DP+EP due to the random token routing.
|
||||
is_profile_run = (
|
||||
@@ -1172,9 +1333,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
# This happens when none of the tokens from the all2all reach this
|
||||
# EP rank. Also, note that this is only relevant for CUDAGraph
|
||||
# incompatible all2all kernels like the DeepEP high-throughput
|
||||
# kernels. CUDAGraph compatible all2all kernels like the pplx
|
||||
# kernels and the DeepEP low-latency kernels are always batched
|
||||
# and can never run into the tensor.numel() == 0 case.
|
||||
# kernels. CUDAGraph compatible all2all kernels like the DeepEP
|
||||
# low-latency kernels are always batched and can never run into
|
||||
# the tensor.numel() == 0 case.
|
||||
if M_full == 0:
|
||||
assert num_chunks == 0
|
||||
workspace13 = None
|
||||
@@ -1313,19 +1474,18 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
assert shared_output is not None
|
||||
return shared_output, output
|
||||
|
||||
def forward(
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
shared_experts_input: torch.Tensor | None = None,
|
||||
**kwargs
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||
@@ -1335,8 +1495,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_weights (torch.Tensor): The topk weights applied at the end of
|
||||
the layer.
|
||||
- topk_weights (torch.Tensor): The topk weights applied at the end of the layer.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- activation (MoEActivation): The activation function to apply after the first
|
||||
MoE layer.
|
||||
@@ -1355,23 +1514,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
from .fused_moe import fused_experts as fused_experts_kernel
|
||||
|
||||
result = fused_experts_kernel(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
quant_config=kwargs.get("quant_config", None),
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
return result
|
||||
if self.inplace:
|
||||
assert self.shared_experts is None
|
||||
assert not disable_inplace()
|
||||
@@ -1417,3 +1559,206 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEKernelMonolithicImpl:
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeMonolithic,
|
||||
fused_experts: FusedMoEExpertsMonolithic,
|
||||
):
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Same as forward(), except uses router_logits as opposed
|
||||
to the topk_ids and topk_weights. This is used for kernels
|
||||
that have fused router + experts (e.g. FLASHINFER_TRTLLM).
|
||||
"""
|
||||
|
||||
# TODO(rob): add inplace support.
|
||||
a1q, a1q_scale, router_logits = self.prepare_finalize.prepare(
|
||||
hidden_states,
|
||||
router_logits=router_logits,
|
||||
quant_config=self.fused_experts.quant_config,
|
||||
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
|
||||
)
|
||||
|
||||
fused_out = self.fused_experts.apply(
|
||||
hidden_states=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
router_logits=router_logits,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
a1q_scale=a1q_scale,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
|
||||
output = self.prepare_finalize.finalize(fused_out)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEKernel:
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEExperts,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
||||
inplace: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_experts = shared_experts # NOTE: check if we can remove
|
||||
|
||||
# Initialize the implementation (monolithic or modular).
|
||||
self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl
|
||||
if isinstance(
|
||||
prepare_finalize, FusedMoEPrepareAndFinalizeModular
|
||||
) and isinstance(fused_experts, FusedMoEExpertsModular):
|
||||
self.impl = FusedMoEKernelModularImpl(
|
||||
prepare_finalize,
|
||||
fused_experts,
|
||||
shared_experts,
|
||||
moe_parallel_config,
|
||||
inplace,
|
||||
)
|
||||
|
||||
elif isinstance(
|
||||
prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic
|
||||
) and isinstance(fused_experts, FusedMoEExpertsMonolithic):
|
||||
assert shared_experts is None
|
||||
assert not inplace
|
||||
self.impl = FusedMoEKernelMonolithicImpl(
|
||||
prepare_finalize,
|
||||
fused_experts,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"prepare_finalize and fused_experts must both be either monolithic "
|
||||
f"or non-monolithic but got {prepare_finalize.__class__.__name__} "
|
||||
f"and {fused_experts.__class__.__name__}"
|
||||
)
|
||||
|
||||
self._post_init_setup()
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return isinstance(self.impl, FusedMoEKernelMonolithicImpl)
|
||||
|
||||
@property
|
||||
def prepare_finalize(self) -> FusedMoEPrepareAndFinalize:
|
||||
return self.impl.prepare_finalize
|
||||
|
||||
@property
|
||||
def fused_experts(self) -> FusedMoEExperts:
|
||||
return self.impl.fused_experts
|
||||
|
||||
def _post_init_setup(self):
|
||||
"""
|
||||
Resolve any leftover setup dependencies between self.prepare_finalize
|
||||
and self.fused_experts here.
|
||||
"""
|
||||
self.prepare_finalize.post_init_setup(self.impl.fused_experts)
|
||||
assert (
|
||||
self.prepare_finalize.activation_format
|
||||
== self.fused_experts.activation_format()
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps.
|
||||
"""
|
||||
return self.fused_experts.supports_expert_map()
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of fused MoE kernel
|
||||
is reduced across all ranks.
|
||||
"""
|
||||
return self.prepare_finalize.output_is_reduced()
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(self.impl, FusedMoEKernelMonolithicImpl)
|
||||
return self.impl.apply(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
router_logits=router_logits,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
shared_experts_input: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(self.impl, FusedMoEKernelModularImpl)
|
||||
return self.impl.apply(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
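
# A rough usage sketch (hedged: only the two required components are shown and
# the remaining keyword arguments are elided):
#
#   kernel = FusedMoEKernel(prepare_finalize, fused_experts)
#   if kernel.is_monolithic:
#       out = kernel.apply_monolithic(hidden_states, w1, w2, router_logits, ...)
#   else:
#       out = kernel.apply(hidden_states, w1, w2, topk_weights, topk_ids, ...)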
|
||||
|
||||
@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Prepare/Finalize using MoRI kernels.
|
||||
"""
|
||||
|
||||
@@ -18,13 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
fp8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
|
||||
is_supported_config_trtllm_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
get_flashinfer_moe_backend,
|
||||
make_fp8_moe_alpha_scales_for_fi,
|
||||
prepare_fp8_moe_layer_for_fi,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
@@ -103,9 +99,13 @@ def _get_priority_backends(
|
||||
|
||||
def backend_to_kernel_cls(
|
||||
backend: Fp8MoeBackend,
|
||||
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
|
||||
) -> type[mk.FusedMoEExperts]:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
raise NotImplementedError
|
||||
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501
|
||||
TrtLlmFp8Experts,
|
||||
)
|
||||
|
||||
return TrtLlmFp8Experts
|
||||
|
||||
elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
@@ -205,13 +205,11 @@ def select_fp8_moe_backend(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
allow_vllm_cutlass: bool = False,
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts] | None]:
|
||||
"""
|
||||
Select the primary FP8 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None
|
||||
|
||||
if config.is_lora_enabled:
|
||||
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
|
||||
|
||||
@@ -252,7 +250,7 @@ def select_fp8_moe_backend(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
@@ -287,16 +285,6 @@ def select_fp8_moe_backend(
|
||||
"vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
|
||||
)
|
||||
|
||||
# Handle FLASHINFER_TRTLLM specially (no kernel class).
|
||||
if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(requested_backend))
|
||||
return requested_backend, None
|
||||
raise ValueError(_make_log_unsupported(requested_backend, reason))
|
||||
|
||||
return _return_or_raise(
|
||||
requested_backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
@@ -311,51 +299,32 @@ def select_fp8_moe_backend(
|
||||
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
|
||||
# If user is explicit about backend, validate it.
|
||||
fi_backend = get_flashinfer_moe_backend()
|
||||
|
||||
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, None
|
||||
else:
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
elif fi_backend == FlashinferMoeBackend.CUTLASS:
|
||||
if fi_backend == FlashinferMoeBackend.CUTLASS:
|
||||
backend = Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
|
||||
elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
else:
|
||||
assert fi_backend == FlashinferMoeBackend.CUTEDSL
|
||||
raise ValueError("FlashInfer MaskedGEMM not supported for FP8")
|
||||
|
||||
raise ValueError(
|
||||
f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
else:
|
||||
# If the user is not explicit about the backend, try both.
|
||||
for backend in [
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
]:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
@@ -408,23 +377,14 @@ def select_fp8_moe_backend(
|
||||
|
||||
# Select kernels in order of backend.
|
||||
for backend in AVAILABLE_BACKENDS:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
@@ -510,7 +470,7 @@ def make_fp8_moe_quant_config(
|
||||
block_shape: list[int] | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Create FusedMoEQuantConfig for the specified FP8 Backend.
|
||||
The FusedMoEQuantConfig holds the scales that are used
|
||||
@@ -523,9 +483,6 @@ def make_fp8_moe_quant_config(
|
||||
In a future PR, this function will become
|
||||
a method of the modular kernel itself.
|
||||
"""
|
||||
# TRTLLM does not use Modular Kernel abstraction yet.
|
||||
if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
|
||||
# MARLIN is mixed precision W8A16 config.
|
||||
if fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
@@ -539,12 +496,6 @@ def make_fp8_moe_quant_config(
|
||||
# (alpha = w_scale * a_scale) and inverse a2 scale.
|
||||
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
|
||||
assert a1_scale is not None and a2_scale is not None
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w1_scale,
|
||||
a1_scale,
|
||||
w2_scale,
|
||||
a2_scale,
|
||||
)
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
@@ -552,8 +503,8 @@ def make_fp8_moe_quant_config(
|
||||
a2_scale=a2_scale,
|
||||
a1_gscale=(1.0 / a1_scale),
|
||||
a2_gscale=(1.0 / a2_scale),
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
g1_alphas=(w1_scale * a1_scale).squeeze(),
|
||||
g2_alphas=(w2_scale * a2_scale).squeeze(),
|
||||
)
|
||||
# All other backends use normal config.
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
@@ -570,17 +521,18 @@ def make_fp8_moe_quant_config(
|
||||
def make_fp8_moe_kernel(
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
experts_cls: type[mk.FusedMoEExperts],
|
||||
fp8_backend: Fp8MoeBackend,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
) -> mk.FusedMoEKernel:
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=moe_quant_config,
|
||||
routing_tables=routing_tables,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
@@ -603,9 +555,9 @@ def make_fp8_moe_kernel(
|
||||
)
|
||||
|
||||
# NOTE(rob): we only want the mk to control the shared_expert
|
||||
# if using all2all (for SBO). bnell is making this explict in
|
||||
# if using all2all (for SBO). bnell is making this explicit in
|
||||
# the new MoE runner class.
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts=(
|
||||
|
||||
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
nvfp4_w4a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
is_supported_config_trtllm,
|
||||
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
@@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
|
||||
|
||||
def backend_to_kernel_cls(
|
||||
backend: NvFp4MoeBackend,
|
||||
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
|
||||
) -> list[type[mk.FusedMoEExperts]]:
|
||||
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
raise NotImplementedError(
|
||||
"FLASHINFER_TRTLLM doesn't support Modular Kernel Interface"
|
||||
from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
|
||||
TrtLlmNvFp4ExpertsModular,
|
||||
TrtLlmNvFp4ExpertsMonolithic,
|
||||
)
|
||||
|
||||
# NOTE: prefer Monolithic > Modular, so return Monolithic first.
|
||||
return [
|
||||
TrtLlmNvFp4ExpertsMonolithic,
|
||||
TrtLlmNvFp4ExpertsModular,
|
||||
]
|
||||
|
||||
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
return FlashInferExperts
|
||||
return [FlashInferExperts]
|
||||
|
||||
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
|
||||
FlashInferCuteDSLExperts,
|
||||
)
|
||||
|
||||
return FlashInferCuteDSLExperts
|
||||
return [FlashInferCuteDSLExperts]
|
||||
|
||||
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassExpertsFp4,
|
||||
)
|
||||
|
||||
return CutlassExpertsFp4
|
||||
return [CutlassExpertsFp4]
|
||||
|
||||
elif backend == NvFp4MoeBackend.MARLIN:
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
MarlinExperts,
|
||||
)
|
||||
|
||||
return MarlinExperts
|
||||
return [MarlinExperts]
|
||||
else:
|
||||
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
|
||||
|
||||
@@ -125,7 +131,7 @@ def select_nvfp4_moe_backend(
|
||||
config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
"""
|
||||
Select the primary NvFP4 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
@@ -143,10 +149,7 @@ def select_nvfp4_moe_backend(
|
||||
# NOTE(rob): this is kind of a hack. We need to peek into
|
||||
# the prepare-finalize selection to determine if we are using
|
||||
# the batched or standard expert format.
|
||||
use_batched = (
|
||||
config.moe_parallel_config.use_deepep_ll_kernels
|
||||
or config.moe_parallel_config.use_pplx_kernels
|
||||
)
|
||||
use_batched = config.moe_parallel_config.use_deepep_ll_kernels
|
||||
activation_format = (
|
||||
mk.FusedMoEActivationFormat.BatchedExperts
|
||||
if use_batched
|
||||
@@ -178,29 +181,21 @@ def select_nvfp4_moe_backend(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, k_cls
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, k_cls
|
||||
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
# Handle explicit moe_backend from user.
|
||||
runner_backend = config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
requested_backend = map_nvfp4_backend(runner_backend)
|
||||
if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(requested_backend))
|
||||
return requested_backend, None
|
||||
raise ValueError(_make_log_unsupported(requested_backend, reason))
|
||||
|
||||
return _return_or_raise(
|
||||
requested_backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
@@ -213,36 +208,14 @@ def select_nvfp4_moe_backend(
|
||||
|
||||
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
|
||||
# If user is explicit about backend, validate it.
|
||||
fi_backend = get_flashinfer_moe_backend()
|
||||
|
||||
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, None
|
||||
else:
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
else:
|
||||
backend = fi_2_vllm_backend_map[fi_backend]
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
else:
|
||||
# If the user is not explicit about the backend, try each.
|
||||
for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
|
||||
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
@@ -250,13 +223,13 @@ def select_nvfp4_moe_backend(
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, None
|
||||
else:
|
||||
logger.debug_once(
|
||||
_make_log_unsupported(backend, reason), scope="local"
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(
|
||||
_make_log_unsupported(backend, reason), scope="local"
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
|
||||
@@ -271,16 +244,7 @@ def select_nvfp4_moe_backend(
|
||||
|
||||
# Select kernels in order of backend.
|
||||
for backend in AVAILABLE_BACKENDS:
|
||||
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None # type: ignore[assignment]
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
@@ -289,11 +253,11 @@ def select_nvfp4_moe_backend(
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
|
||||
|
||||
raise NotImplementedError(
|
||||
"No NvFp4 MoE backend supports the deployment configuration."
|
||||
@@ -401,12 +365,8 @@ def make_nvfp4_moe_quant_config(
|
||||
w2_scale_2: torch.Tensor,
|
||||
a13_scale: torch.Tensor,
|
||||
a2_scale: torch.Tensor,
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM]
|
||||
if backend in UNSUPPORTED:
|
||||
return None
|
||||
|
||||
elif backend == NvFp4MoeBackend.MARLIN:
|
||||
) -> FusedMoEQuantConfig:
|
||||
if backend == NvFp4MoeBackend.MARLIN:
|
||||
return nvfp4_w4a16_moe_quant_config(
|
||||
g1_alphas=w13_scale_2,
|
||||
g2_alphas=w2_scale_2,
|
||||
@@ -423,22 +383,27 @@ def make_nvfp4_moe_quant_config(
|
||||
a2_gscale=(1.0 / a2_scale),
|
||||
w1_scale=w13_scale,
|
||||
w2_scale=w2_scale,
|
||||
# NOTE(rob): this is a hack until the MoE kernels
|
||||
# create their own quant configs. TRTLLM kernel
|
||||
# does not accept swizzled input quant scales.
|
||||
is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM),
|
||||
)
|
||||
|
||||
|
||||
def make_nvfp4_moe_kernel(
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
experts_cls: type[mk.FusedMoEExperts],
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
) -> mk.FusedMoEKernel:
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=moe_quant_config,
|
||||
routing_tables=routing_tables,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
@@ -461,9 +426,9 @@ def make_nvfp4_moe_kernel(
|
||||
)
|
||||
|
||||
# NOTE(rob): we only want the mk to control the shared_expert
|
||||
# if using all2all (for SBO). bnell is making this explict in
|
||||
# if using all2all (for SBO). bnell is making this explicit in
|
||||
# the new MoE runner class.
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts=(
|
||||
|
||||
@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
|
||||
is_supported_config_trtllm_bf16,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
swap_w13_to_w31,
|
||||
@@ -209,7 +209,7 @@ def make_unquantized_moe_kernel(
|
||||
backend: UnquantizedMoeBackend,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
) -> mk.FusedMoEModularKernel | None:
|
||||
) -> mk.FusedMoEKernel | None:
|
||||
if backend in UNSUPPORTED_BACKEND:
|
||||
return None
|
||||
|
||||
@@ -218,8 +218,8 @@ def make_unquantized_moe_kernel(
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
FlashInferExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
@@ -232,8 +232,8 @@ def make_unquantized_moe_kernel(
|
||||
AiterExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
AiterExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
@@ -241,25 +241,6 @@ def make_unquantized_moe_kernel(
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
elif backend == UnquantizedMoeBackend.TRITON:
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
elif backend == UnquantizedMoeBackend.XPU:
|
||||
from vllm.model_executor.layers.fused_moe import XPUExperts
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
XPUExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
kernel = fused_experts
|
||||
return kernel
|
||||
|
||||
@@ -1,373 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import pplx_kernels as pplx
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_validate_scale_shape,
|
||||
moe_kernel_quantize_input,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def pplx_hidden_dim_scale_bytes(
|
||||
max_num_tokens: int,
|
||||
hidden_dim: int,
|
||||
in_dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype | str | None,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: list[int] | None,
|
||||
):
|
||||
# All pplx byte sizes must be 16-byte aligned.
|
||||
align = 16
|
||||
|
||||
# For blocked per token: set to
|
||||
# cdiv(hidden_dim, block_size) * sizeof(float32)
|
||||
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
|
||||
if quant_dtype is not None:
|
||||
assert isinstance(quant_dtype, torch.dtype)
|
||||
assert quant_dtype.itemsize == 1
|
||||
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
|
||||
elem_size = torch.float32.itemsize
|
||||
|
||||
if per_act_token_quant:
|
||||
# per-token (M x 1)
|
||||
assert block_shape is None
|
||||
hidden_scale_bytes = elem_size
|
||||
elif block_shape is not None:
|
||||
# per-group (M x K_tiles)
|
||||
block_size = block_shape[1]
|
||||
num_blocks = cdiv(hidden_dim, block_size)
|
||||
hidden_scale_bytes = num_blocks * elem_size
|
||||
else:
|
||||
# per-tensor (1 x 1)
|
||||
hidden_scale_bytes = elem_size
|
||||
else:
|
||||
hidden_dim_bytes = hidden_dim * in_dtype.itemsize
|
||||
hidden_scale_bytes = 0
|
||||
|
||||
return (
|
||||
round_up(hidden_dim_bytes, align),
|
||||
round_up(hidden_scale_bytes, align),
|
||||
)
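For reference, a worked example of the sizing logic above, assuming fp8 activations (1 byte per element), float32 scales, hidden_dim = 7168 and a [128, 128] block shape; the numbers are illustrative only:

hidden_dim, block_size, align = 7168, 128, 16
elem_size = 4                                    # float32 scales
hidden_dim_bytes = hidden_dim * 1                # 1-byte quantized elements
hidden_scale_bytes = -(-hidden_dim // block_size) * elem_size   # cdiv -> 56 blocks * 4 bytes = 224

def _round_up(x: int, a: int) -> int:
    return (x + a - 1) // a * a

assert _round_up(hidden_dim_bytes, align) == 7168
assert _round_up(hidden_scale_bytes, align) == 224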
|
||||
|
||||
|
||||
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
"""PPLX-based prepare and finalize for expert parallelism."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
a2a: pplx.AllToAll,
|
||||
max_num_tokens: int,
|
||||
num_local_experts: int,
|
||||
num_dispatchers: int,
|
||||
):
|
||||
super().__init__()
|
||||
assert max_num_tokens > 0
|
||||
assert num_local_experts > 0
|
||||
self.a2a = a2a
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_local_experts = num_local_experts
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return self.max_num_tokens
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return torch.uint32
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self.num_dispatchers_
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[Callable, mk.ReceiverType]:
|
||||
if defer_input_quant:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not support defer_input_quant=True. "
|
||||
"Please select an MoE kernel that accepts quantized inputs."
|
||||
)
|
||||
|
||||
num_tokens = a1.size(0) # M
|
||||
hidden_dim = a1.size(-1) # K
|
||||
|
||||
assert topk_ids.size(0) == num_tokens
|
||||
# expert_map should be None because with expert map, -1 id is used for
|
||||
# non-local token; this causes error when casting ids to the
|
||||
# topk_indices_dtype() int32
|
||||
#
|
||||
if expert_map is not None:
|
||||
logger.warning_once(
|
||||
"The PPLX backend does not support expert mapping. "
|
||||
"The provided `expert_map` will be ignored."
|
||||
)
|
||||
expert_map = None # noqa: F841
|
||||
|
||||
# Is this always going to be a1.device?
|
||||
device = a1.device
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
repeat_cols = 4
|
||||
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
|
||||
# TODO(bnell): always pass quant_config.a1_scale?
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
(None if quant_config.per_act_token_quant else quant_config.a1_scale),
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
)
|
||||
|
||||
_validate_scale_shape(
|
||||
a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape
|
||||
)
|
||||
|
||||
orig_a_scale_block_shape: int | None = None
|
||||
|
||||
if a1q_scale is not None:
|
||||
scalar_scales = a1q_scale.numel() == 1
|
||||
|
||||
# pplx requires 2-d scales even for scalar scales
|
||||
if a1q_scale.dim() <= 1:
|
||||
assert scalar_scales
|
||||
a1q_scale = a1q_scale.view(1, 1)
|
||||
|
||||
orig_a_scale_block_shape = a1q_scale.shape[-1]
|
||||
|
||||
if not quant_config.is_block_quantized:
|
||||
# TODO (bnell): use group_broadcast instead?
|
||||
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
|
||||
|
||||
assert a1q_scale is None or a1q_scale.ndim == 2, (
|
||||
f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
|
||||
)
|
||||
|
||||
expert_num_tokens = torch.empty(
|
||||
self.num_local_experts,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expert_x = torch.empty(
|
||||
(
|
||||
self.num_local_experts,
|
||||
self.max_num_tokens * self.num_dispatchers(),
|
||||
hidden_dim,
|
||||
),
|
||||
dtype=a1q.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expert_x_scale: torch.Tensor | None = None
|
||||
if a1q.dtype.itemsize == 1:
|
||||
if quant_config.is_per_act_token:
|
||||
# (M x 1) -> (E x M x K)
|
||||
final_dim = expert_x.size(2)
|
||||
elif quant_config.is_per_tensor:
|
||||
# (1 x 1) -> (E x 1 x 1)
|
||||
final_dim = 1
|
||||
else:
|
||||
# (M x K_tiles) -> (E x M x K_tiles)
|
||||
assert quant_config.block_shape is not None
|
||||
num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1])
|
||||
final_dim = num_blocks
|
||||
|
||||
expert_x_scale_shape = (
|
||||
self.num_local_experts,
|
||||
expert_x.size(1),
|
||||
round_up(final_dim, 4), # round up for alignment
|
||||
)
|
||||
|
||||
expert_x_scale = torch.empty(
|
||||
expert_x_scale_shape,
|
||||
dtype=torch.float32,
|
||||
device=expert_x.device,
|
||||
)
|
||||
|
||||
# This argument is optional, defaults to indices.size(0)
|
||||
# There's not much point setting this unless it is != indices.size(0)
|
||||
bound_m: torch.Tensor | None = None
|
||||
|
||||
self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=topk_ids,
|
||||
bound_m=bound_m,
|
||||
do_send=True,
|
||||
do_recv=False,
|
||||
)
|
||||
|
||||
hook = lambda: self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=topk_ids,
|
||||
bound_m=bound_m,
|
||||
do_send=False,
|
||||
do_recv=True,
|
||||
)
|
||||
|
||||
return (
|
||||
hook,
|
||||
lambda: self._receiver(
|
||||
expert_num_tokens,
|
||||
expert_x,
|
||||
expert_x_scale,
|
||||
orig_a_scale_block_shape,
|
||||
),
|
||||
)
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
expert_x: torch.Tensor,
|
||||
expert_x_scale: torch.Tensor | None,
|
||||
orig_a_scale_block_shape: int | None,
|
||||
) -> mk.PrepareResultType:
|
||||
if expert_x_scale is not None:
|
||||
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
|
||||
assert expert_x_scale.ndim == 3
|
||||
|
||||
expert_tokens_meta = mk.ExpertTokensMetadata(
|
||||
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
|
||||
)
|
||||
|
||||
return expert_x, expert_x_scale, expert_tokens_meta, None, None
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
hook, receiver = self.prepare_async(
|
||||
a1,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config,
|
||||
defer_input_quant=defer_input_quant,
|
||||
)
|
||||
hook()
|
||||
return receiver()
|
||||
|
||||
def finalize_async(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> Callable:
|
||||
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
|
||||
"Weight application and reduction happens in the combine kernel."
|
||||
)
|
||||
|
||||
# This argument is optional
|
||||
# There's not much point setting this unless it is != topk_ids.size(0)
|
||||
bound_m: torch.Tensor | None = None
|
||||
|
||||
# TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
|
||||
# num_tokens = output.size(0) # M
|
||||
# assert topk_ids.size(0) == num_tokens, (
|
||||
# f"{topk_ids.size(0)} == {num_tokens}")
|
||||
assert topk_ids.size() == topk_weights.size(), (
|
||||
f"{topk_ids.size()} == {topk_weights.size()}"
|
||||
)
|
||||
assert output.size(0) <= self.max_num_tokens, (
|
||||
f"{output.size(0)} <= {self.max_num_tokens}"
|
||||
)
|
||||
assert output.size(1) == fused_expert_output.size(-1)
|
||||
|
||||
# Set weights to 1 if we did them in dispatch. This is hacky.
|
||||
if apply_router_weight_on_input:
|
||||
topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
|
||||
|
||||
self.a2a.combine(
|
||||
out_tokens=output,
|
||||
indices=topk_ids_u32,
|
||||
weights=topk_weights,
|
||||
expert_y=fused_expert_output,
|
||||
bound_m=bound_m,
|
||||
do_send=True,
|
||||
do_recv=False,
|
||||
)
|
||||
|
||||
return lambda: self.a2a.combine(
|
||||
out_tokens=output,
|
||||
indices=topk_ids_u32,
|
||||
weights=topk_weights,
|
||||
expert_y=fused_expert_output,
|
||||
bound_m=bound_m,
|
||||
do_send=False,
|
||||
do_recv=True,
|
||||
)
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
receiver = self.finalize_async(
|
||||
output,
|
||||
fused_expert_output,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
weight_and_reduce_impl,
|
||||
)
|
||||
receiver()
|
||||
@@ -1,209 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous,
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
|
||||
def __init__(
|
||||
self,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
self._num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self._num_dispatchers
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Defer input quantization to the MoE kernel.
|
||||
use_nvfp4 = quant_config.use_nvfp4_w4a4
|
||||
if defer_input_quant:
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
else:
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
# NOTE: swizzling pads the scales to multiple of 128
|
||||
# which makes the scales tensor different shape than
|
||||
# the hidden states, breaking the A2A kernel. So, we
|
||||
# delay the swizzling until after the A2A.
|
||||
is_fp4_scale_swizzled=False,
|
||||
)
|
||||
|
||||
# Skip gathering scales if we have static quantization
|
||||
# (the scale is a scalar, replicated on all ranks) or
|
||||
# if quantization is deferred.
|
||||
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
|
||||
scales = None if skip_gather_scales else [a1q_scale]
|
||||
|
||||
res = get_ep_group().dispatch(
|
||||
a1q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
extra_tensors=scales,
|
||||
)
|
||||
if skip_gather_scales:
|
||||
a1q, topk_weights, topk_ids = res
|
||||
else:
|
||||
a1q, topk_weights, topk_ids, scales = res
|
||||
assert scales is not None and len(scales) == 1
|
||||
a1q_scale = scales[0]
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
assert a1q_scale is not None
|
||||
if a1q_scale.element_size() == 1:
|
||||
a1q_scale = a1q_scale.view(torch.uint8)
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
|
||||
out = weight_and_reduce_impl.apply(
|
||||
output=None,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
output.copy_(
|
||||
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
|
||||
)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
"""MoE prepare and finalize without expert parallelism."""
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return 1
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
|
||||
# which use a single kernel call for quant + experts.
|
||||
if defer_input_quant:
|
||||
return a1, None, None, None, None
|
||||
|
||||
input_sf = (
|
||||
quant_config.a1_gscale
|
||||
if quant_config.use_nvfp4_w4a4
|
||||
else quant_config.a1_scale
|
||||
)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
input_sf,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
)
|
||||
|
||||
return a1q, a1q_scale, None, None, None
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
weight_and_reduce_impl.apply(
|
||||
output=output,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import (
|
||||
MoEPrepareAndFinalizeNaiveDPEPModular,
|
||||
MoEPrepareAndFinalizeNaiveDPEPMonolithic,
|
||||
make_moe_prepare_and_finalize_naive_dp_ep,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import (
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
MoEPrepareAndFinalizeNoDPEPMonolithic,
|
||||
make_moe_prepare_and_finalize_no_dp_ep,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MoEPrepareAndFinalizeNaiveDPEPMonolithic",
|
||||
"MoEPrepareAndFinalizeNaiveDPEPModular",
|
||||
"make_moe_prepare_and_finalize_naive_dp_ep",
|
||||
"MoEPrepareAndFinalizeNoDPEPMonolithic",
|
||||
"MoEPrepareAndFinalizeNoDPEPModular",
|
||||
"make_moe_prepare_and_finalize_no_dp_ep",
|
||||
]
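A minimal usage sketch of the re-exported factories (illustrative; the import path follows the package layout introduced above):

from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    make_moe_prepare_and_finalize_no_dp_ep,
)

# Modular experts keep topk routing outside the kernel, so use_monolithic=False here.
prepare_finalize = make_moe_prepare_and_finalize_no_dp_ep(use_monolithic=False)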
|
||||
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous,
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
def _quantize_and_setup_dispatch(
|
||||
a1: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
|
||||
# Defer input quantization to the MoE kernel.
|
||||
if defer_input_quant:
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
else:
|
||||
input_sf = (
|
||||
quant_config.a1_gscale
|
||||
if quant_config.use_nvfp4_w4a4
|
||||
else quant_config.a1_scale
|
||||
)
|
||||
|
||||
# NOTE: swizzling pads the scales to multiple of 128
|
||||
# which makes the scales tensor different shape than
|
||||
# the hidden states, breaking the A2A kernel. So, we
|
||||
# delay the swizzling until after the A2A.
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
input_sf,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=False,
|
||||
)
|
||||
|
||||
# Skip gathering scales if we have static quantization
|
||||
# (the scale is a scalar, replicated on all ranks) or
|
||||
# if quantization is deferred.
|
||||
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
|
||||
scales = None if skip_gather_scales else [a1q_scale]
|
||||
|
||||
return a1q, scales
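A small sketch of the gather decision above (illustrative): a scalar per-tensor scale is already replicated on every rank, so only non-scalar scales need to ride along in the all-to-all.

import torch

for a1q_scale in (None, torch.tensor(0.5), torch.rand(16, 1)):
    skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
    scales = None if skip_gather_scales else [a1q_scale]
    # Only the per-token (16 x 1) scale yields a non-None `scales` list.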
|
||||
|
||||
|
||||
def _unwrap_scale_and_prepare_for_moe(
|
||||
scales: list[torch.Tensor] | None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> torch.Tensor:
|
||||
assert scales is not None and len(scales) == 1
|
||||
a1q_scale = scales[0]
|
||||
# Apply swizzling after a2a if the MoE kernel needs it.
|
||||
if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
|
||||
assert a1q_scale is not None
|
||||
if a1q_scale.element_size() == 1:
|
||||
a1q_scale = a1q_scale.view(torch.uint8)
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q_scale
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Naive Prepare/Finalize for the DP/EP case for Modular Kernels.
|
||||
|
||||
Uses Torch AR/RS or AR for dispatch/combine operations, applied
|
||||
to the topk weights and ids.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
self._num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self._num_dispatchers
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
"""Quantize and Dispatch Topk Weights and Topk Ids."""
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
|
||||
|
||||
res = get_ep_group().dispatch(
|
||||
a1q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
extra_tensors=scales,
|
||||
)
|
||||
|
||||
if scales is None:
|
||||
a1q, topk_weights, topk_ids = res
|
||||
a1q_scale = None
|
||||
else:
|
||||
a1q, topk_weights, topk_ids, scales = res
|
||||
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
|
||||
out = weight_and_reduce_impl.apply(
|
||||
output=None,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
output.copy_(
|
||||
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
|
||||
)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
|
||||
"""
|
||||
Naive Prepare/Finalize for the DP/EP case for Monolithic Kernels.
|
||||
|
||||
Uses Torch AR/RS or AR for dispatch/combine operations, applied
|
||||
to the router logits (the MoE kernel runs the router internally).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
self._num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self._num_dispatchers
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareMonolithicResultType:
|
||||
"""Quantize and Dispatch Router Logits."""
|
||||
|
||||
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
|
||||
|
||||
res = get_ep_group().dispatch_router_logits(
|
||||
a1q,
|
||||
router_logits,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
extra_tensors=scales,
|
||||
)
|
||||
|
||||
if scales is None:
|
||||
a1q, router_logits = res
|
||||
a1q_scale = None
|
||||
else:
|
||||
a1q, router_logits, scales = res
|
||||
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
|
||||
|
||||
return a1q, a1q_scale, router_logits
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
fused_expert_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
out = get_ep_group().combine(
|
||||
fused_expert_output, is_sequence_parallel=self.is_sequence_parallel
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def make_moe_prepare_and_finalize_naive_dp_ep(
|
||||
use_monolithic: bool,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> MoEPrepareAndFinalizeNaiveDPEPModular | MoEPrepareAndFinalizeNaiveDPEPMonolithic:
|
||||
return (
|
||||
MoEPrepareAndFinalizeNaiveDPEPMonolithic(
|
||||
is_sequence_parallel=is_sequence_parallel,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
if use_monolithic
|
||||
else MoEPrepareAndFinalizeNaiveDPEPModular(
|
||||
is_sequence_parallel=is_sequence_parallel,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
)
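Illustrative only: a caller would typically pick between the two variants from the experts implementation, mirroring the use_monolithic flag passed by make_nvfp4_moe_kernel earlier in this change (experts_cls and mk are assumed to be in scope as they are there):

prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep(
    use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
    is_sequence_parallel=False,
    num_dispatchers=1,
)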
|
||||
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous,
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
|
||||
|
||||
def _quantize_input(
|
||||
a1: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
|
||||
# which use a single kernel call for quant + experts.
|
||||
if defer_input_quant:
|
||||
return a1, None
|
||||
|
||||
input_sf = (
|
||||
quant_config.a1_gscale if quant_config.use_nvfp4_w4a4 else quant_config.a1_scale
|
||||
)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
input_sf,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
|
||||
)
|
||||
|
||||
return a1q, a1q_scale
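A quick sketch of the defer path above (illustrative): with defer_input_quant=True the activations pass through untouched and no scale is produced, leaving quantization to the expert kernel.

import torch

x = torch.randn(4, 7168, dtype=torch.bfloat16)
# quant_config is not touched on the defer path, so None stands in here for brevity.
a1q, a1q_scale = _quantize_input(x, None, defer_input_quant=True)
assert a1q is x and a1q_scale is None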
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return 1
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
|
||||
|
||||
return a1q, a1q_scale, None, None, None
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
weight_and_reduce_impl.apply(
|
||||
output=output,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return 1
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareMonolithicResultType:
|
||||
a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
|
||||
return a1q, a1q_scale, router_logits
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
fused_expert_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return fused_expert_output
|
||||
|
||||
|
||||
def make_moe_prepare_and_finalize_no_dp_ep(
|
||||
use_monolithic: bool,
|
||||
) -> MoEPrepareAndFinalizeNoDPEPModular | MoEPrepareAndFinalizeNoDPEPMonolithic:
|
||||
return (
|
||||
MoEPrepareAndFinalizeNoDPEPMonolithic()
|
||||
if use_monolithic
|
||||
else MoEPrepareAndFinalizeNoDPEPModular()
|
||||
)
|
||||
@@ -292,7 +292,7 @@ def rocm_aiter_fused_experts(
|
||||
)
|
||||
|
||||
|
||||
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -132,7 +133,7 @@ class RoutedExpertsCapturer:
|
||||
self._device_buffer = torch.zeros(
|
||||
(max_num_batched_tokens, num_layers, num_experts_per_tok),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
device=current_platform.device_type,
|
||||
)
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ if current_platform.is_cuda_alike():
|
||||
|
||||
# TODO(bowen): When using `FusedMoEModularKernel`, this
|
||||
# can be done in a more unified way, since
|
||||
# `FusedMoEPrepareAndFinalize` will return the expert
|
||||
# `FusedMoEPrepareAndFinalizeModular` will return the expert
|
||||
# token count, in some cases directly from the kernel.
|
||||
# However, now there are many code paths not using
|
||||
# the modular kernel, e.g. calling `fused_experts`,
|
||||
@@ -175,6 +175,7 @@ class BaseRouter(FusedMoERouter):
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
|
||||
assert topk_ids.dtype == indices_type or indices_type is None
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
return topk_ids
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -31,7 +31,7 @@ def vllm_topk_softmax(
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
e_score_correction_bias,
|
||||
e_score_correction_bias
|
||||
)
|
||||
|
||||
return topk_weights, topk_indices
|
||||
@@ -85,13 +85,14 @@ def fused_topk_bias(
|
||||
token_expert_indices = torch.empty(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
|
||||
|
||||
if scoring_func == "softmax":
|
||||
topk_weights, topk_ids = vllm_topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
gating_output_float,
|
||||
renormalize,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
@@ -186,7 +187,7 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
indices_type=indices_type,
|
||||
)
|
||||
|
||||
if self.routed_scaling_factor != 1.0:
|
||||
topk_weights *= self.routed_scaling_factor
|
||||
# if self.routed_scaling_factor != 1.0:
|
||||
# topk_weights *= self.routed_scaling_factor
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@@ -26,8 +26,9 @@ def vllm_topk_softmax(
|
||||
topk_indices,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
)
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_indices
|
||||
|
||||
@@ -90,13 +91,14 @@ def fused_topk(
|
||||
token_expert_indices = torch.empty(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
gating_output_float = gating_output.float()
|
||||
|
||||
if scoring_func == "softmax":
|
||||
topk_func = dispatch_topk_softmax_func(
|
||||
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
topk_weights, topk_ids = topk_func(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids, token_expert_indices
|
||||
@@ -105,7 +107,7 @@ def fused_topk(
|
||||
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
topk_weights, topk_ids = topk_func(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids, token_expert_indices
|
||||
|
||||
vllm/model_executor/layers/fused_moe/router/gate_linear.py (new file, 115 lines)
@@ -0,0 +1,115 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@PluggableLayer.register("gate_linear")
|
||||
class GateLinear(ReplicatedLinear):
|
||||
"""MoE gate linear layer with three-tier GEMM dispatch:
|
||||
|
||||
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
|
||||
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
|
||||
3. F.linear via ReplicatedLinear (ultimate fallback)
|
||||
|
||||
The ``out_dtype`` attribute is mutable and can be set after init
|
||||
(e.g. when the required dtype depends on the expert quantization
|
||||
method which is only known later).
|
||||
"""
|
||||
|
||||
# Dimensions supported by the DSV3 specialized kernel
|
||||
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
|
||||
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = False,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
force_fp32_compute: bool = False,
|
||||
prefix: str = "",
|
||||
):
|
||||
is_hopper_or_blackwell = current_platform.is_device_capability(
|
||||
(9, 0)
|
||||
) or current_platform.is_device_capability_family(100)
|
||||
can_use_specialized_kernels = False
|
||||
|
||||
# If fp32 compute is required and no specialized kernel is available,
|
||||
# store weights in fp32 so Tier 3 computes in fp32 natively.
|
||||
if force_fp32_compute and not can_use_specialized_kernels:
|
||||
params_dtype = torch.float32
|
||||
|
||||
super().__init__(
|
||||
input_size,
|
||||
output_size,
|
||||
bias=bias,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=None,
|
||||
prefix=prefix,
|
||||
)
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
# DSV3 specialized kernel eligibility (SM90+, exact dims)
|
||||
self.allow_specialized_router_gemm = can_use_specialized_kernels
|
||||
self.allow_dsv3_router_gemm = (
|
||||
self.allow_specialized_router_gemm
|
||||
and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS
|
||||
and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
|
||||
)
|
||||
|
||||
# cuBLAS bf16→fp32 eligibility
|
||||
self.allow_cublas_router_gemm = (
|
||||
self.allow_specialized_router_gemm
|
||||
and self.weight.dtype == torch.bfloat16
|
||||
and self.out_dtype == torch.float32
|
||||
)
|
||||
|
||||
def set_out_dtype(self, out_dtype: torch.dtype) -> None:
|
||||
"""Set output dtype for the router logits after init.
|
||||
|
||||
Useful when the required dtype depends on the expert quantization
|
||||
method which is only known after the gate is constructed.
|
||||
"""
|
||||
if self.out_dtype is not None:
|
||||
raise ValueError("out_dtype has already been set")
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
if (
|
||||
not self.allow_cublas_router_gemm
|
||||
and self.allow_specialized_router_gemm
|
||||
and out_dtype == torch.float32
|
||||
):
|
||||
self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
# Tier 1: DSV3 specialized kernel
|
||||
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
|
||||
output = ops.dsv3_router_gemm(
|
||||
hidden_states=x,
|
||||
router_weight=self.weight,
|
||||
output_dtype=self.out_dtype,
|
||||
)
|
||||
return output, None
|
||||
|
||||
# Tier 2: cuBLAS bf16→fp32
|
||||
if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
|
||||
output = ops.router_gemm_bf16_fp32(x, self.weight)
|
||||
return output, None
|
||||
|
||||
# Tier 3: F.linear (ReplicatedLinear)
|
||||
if self.out_dtype is not None and x.dtype != self.weight.dtype:
|
||||
x = x.to(self.weight.dtype)
|
||||
output, output_bias = super().forward(x)
|
||||
if self.out_dtype is not None and output.dtype != self.out_dtype:
|
||||
output = output.to(self.out_dtype)
|
||||
return output, output_bias
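A hedged usage sketch for the new gate (the dimensions and the hidden_states tensor are illustrative; a real model builds this through its MoE layer):

# 256 routed experts over a 7168-wide hidden state, matching the DSV3 dims above.
gate = GateLinear(input_size=7168, output_size=256, bias=False)
gate.set_out_dtype(torch.float32)        # e.g. once the expert quant method is known
router_logits, _ = gate(hidden_states)   # hidden_states: [num_tokens, 7168]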
|
||||
@@ -92,77 +92,9 @@ def grouped_topk(
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if (
|
||||
envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
|
||||
and current_platform.is_cuda()
|
||||
and num_expert_group <= 32
|
||||
and topk <= 32
|
||||
and e_score_correction_bias is not None
|
||||
):
|
||||
return fused_grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=gating_output,
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
|
||||
|
||||
if scoring_func == "softmax":
|
||||
scores = torch.softmax(gating_output, dim=-1)
|
||||
elif scoring_func == "sigmoid":
|
||||
scores = gating_output.sigmoid()
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
num_token = scores.size(0)
|
||||
if e_score_correction_bias is not None:
|
||||
# Store original scores before applying correction bias. We use biased
|
||||
# scores for expert selection but original scores for routing weights
|
||||
original_scores = scores
|
||||
scores = scores + e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
|
||||
)
|
||||
else:
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||
) # [n, n_group]
|
||||
|
||||
# For batch invariance, use sorted=True to ensure deterministic expert selection
|
||||
use_sorted = vllm_is_batch_invariant()
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
|
||||
1
|
||||
] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
|
||||
.reshape(num_token, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_scores.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(
|
||||
tmp_scores, k=topk, dim=-1, sorted=use_sorted
|
||||
)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
if routed_scaling_factor != 1.0:
|
||||
topk_weights = topk_weights * routed_scaling_factor
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
from ixformer.inference.functions import moe_grouped_topk as grouped_topk
|
||||
topk_weights, topk_ids = grouped_topk(gating_output, topk, num_expert_group, topk_group, scoring_func, e_score_correction_bias, renormalize=renormalize)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
# --8<-- [start:grouped_topk]
|
||||
@@ -246,7 +178,6 @@ class GroupedTopk(CustomOp):
|
||||
hidden_states, gating_output, e_score_correction_bias
|
||||
)
|
||||
|
||||
from ixformer.inference.functions import moe_grouped_topk as grouped_topk
|
||||
|
||||
class GroupedTopKRouter(BaseRouter):
|
||||
"""Router using grouped top-k routing (e.g., DeepSeekV2/V3)."""
|
||||
@@ -316,8 +247,8 @@ class GroupedTopKRouter(BaseRouter):
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
)
|
||||
if self.routed_scaling_factor != 1.0:
|
||||
topk_weights *= self.routed_scaling_factor
|
||||
# if self.routed_scaling_factor != 1.0:
|
||||
# topk_weights *= self.routed_scaling_factor
|
||||
else:
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
@@ -340,14 +271,14 @@ class GroupedTopKRouter(BaseRouter):
|
||||
grouped_topk_impl = grouped_topk
|
||||
|
||||
topk_weights, topk_ids = grouped_topk_impl(
|
||||
# hidden_states=hidden_states,
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
num_expert_group=self.num_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
scoring_func=self.scoring_func,
|
||||
# routed_scaling_factor=self.routed_scaling_factor,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ def create_fused_moe_router(
|
||||
# grouped topk + fused topk bias parameters
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
# custom routing paramaters
|
||||
# custom routing parameters
|
||||
custom_routing_function: Callable | None = None,
|
||||
# eplb parameters
|
||||
enable_eplb: bool = False,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import (
|
||||
HAS_OPAQUE_TYPE,
|
||||
ModuleName,
|
||||
aux_stream,
|
||||
current_stream,
|
||||
direct_register_custom_op,
|
||||
@@ -56,13 +59,27 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
|
||||
return forward_context.no_compile_layers[layer_name]
|
||||
|
||||
|
||||
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
|
||||
# on older versions it remains a plain str.
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
|
||||
_layer_name_type: TypeAlias = str | ModuleName
|
||||
else:
|
||||
_layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str
|
||||
|
||||
|
||||
def _resolve_layer_name(layer_name: str | ModuleName) -> str:
|
||||
return layer_name.value if isinstance(layer_name, ModuleName) else layer_name
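# Minimal sketch of the str/ModuleName duality resolved above, using a
# hypothetical stand-in class (the real opaque type comes from
# vllm.utils.torch_utils and only exposes the string via `.value`).
class _FakeModuleName:
    def __init__(self, value: str):
        self.value = value

def resolve(name):
    return name.value if isinstance(name, _FakeModuleName) else name

assert resolve("model.layers.0.mlp.experts") == "model.layers.0.mlp.experts"
assert resolve(_FakeModuleName("model.layers.0.mlp.experts")) == "model.layers.0.mlp.experts"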
|
||||
|
||||
|
||||
def _moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> torch.Tensor:
|
||||
layer = get_layer_from_name(layer_name)
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
return layer.runner.forward_impl(
|
||||
@@ -74,7 +91,7 @@ def _moe_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
@@ -83,9 +100,9 @@ def _moe_forward_shared(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
layer = get_layer_from_name(layer_name)
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
return layer.runner.forward_impl(
|
||||
@@ -97,7 +114,7 @@ def _moe_forward_shared_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Output shapes:
|
||||
# - fused_out: same as hidden_states (routed experts use transformed size)
|
||||
@@ -105,12 +122,10 @@ def _moe_forward_shared_fake(
|
||||
# hidden_states
|
||||
# (For latent MoE: shared experts use original hidden_size, not latent size)
|
||||
fused_out = torch.empty_like(hidden_states)
|
||||
|
||||
if shared_experts_input is not None:
|
||||
shared_out = torch.empty_like(shared_experts_input)
|
||||
else:
|
||||
shared_out = torch.empty_like(hidden_states)
|
||||
|
||||
return shared_out, fused_out
|
||||
|
||||
|
||||
@@ -165,6 +180,7 @@ class DefaultMoERunner(MoERunner):
|
||||
quant_method: FusedMoEMethodBase,
|
||||
reduce_results: bool,
|
||||
enable_dbo: bool,
|
||||
fused_shared_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.moe_config = moe_config
|
||||
@@ -175,6 +191,9 @@ class DefaultMoERunner(MoERunner):
|
||||
self.quant_method = quant_method
|
||||
self.reduce_results = reduce_results
|
||||
self.enable_dbo = enable_dbo
|
||||
self.fused_shared_output = fused_shared_output
|
||||
if self.fused_shared_output:
|
||||
assert self.shared_experts is not None, "Shared experts must be provided when fused_shared_output is True."
|
||||
|
||||
# Allow disabling of the separate shared experts stream for
|
||||
# debug purposes.
|
||||
@@ -195,19 +214,19 @@ class DefaultMoERunner(MoERunner):
|
||||
# Needed for string -> FusedMoE layer lookup in custom ops.
|
||||
self.layer_name = layer.layer_name
|
||||
|
||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
# Note: CPU doesn't require wrapped forward_impl.
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = _moe_forward
|
||||
else:
|
||||
self.moe_forward = _moe_forward_shared
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = _moe_forward
|
||||
else:
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = torch.ops.vllm.moe_forward
|
||||
else:
|
||||
self.moe_forward = torch.ops.vllm.moe_forward_shared
|
||||
self.moe_forward = _moe_forward_shared
|
||||
# else:
|
||||
# if self.shared_experts is None:
|
||||
# self.moe_forward = torch.ops.vllm.moe_forward
|
||||
# else:
|
||||
# self.moe_forward = torch.ops.vllm.moe_forward_shared
|
||||
|
||||
# Chunked all2all staging tensor
|
||||
self.batched_hidden_states: torch.Tensor | None = None
|
||||
@@ -216,8 +235,7 @@ class DefaultMoERunner(MoERunner):
|
||||
@property
|
||||
def use_dp_chunking(self) -> bool:
|
||||
return (
|
||||
self.moe_config.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_config.moe_parallel_config.use_deepep_ll_kernels
|
||||
self.moe_config.moe_parallel_config.use_deepep_ll_kernels
|
||||
or self.moe_config.moe_parallel_config.use_mori_kernels
|
||||
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
|
||||
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
|
||||
@@ -306,8 +324,8 @@ class DefaultMoERunner(MoERunner):
|
||||
"""
|
||||
assert self.quant_method is not None
|
||||
return (
|
||||
self.quant_method.moe_mk is not None
|
||||
and self.quant_method.moe_mk.output_is_reduced()
|
||||
self.quant_method.moe_kernel is not None
|
||||
and self.quant_method.moe_kernel.output_is_reduced()
|
||||
)
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
|
||||
@@ -362,13 +380,15 @@ class DefaultMoERunner(MoERunner):
|
||||
|
||||
if isinstance(states, tuple):
|
||||
return tuple(
|
||||
[func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
|
||||
[None if s is None else func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
|
||||
)
|
||||
else:
|
||||
assert len(trunc_sizes) == 1
|
||||
return func(states, trunc_sizes[0])
|
||||
|
||||
def _encode_layer_name(self) -> str:
|
||||
def _encode_layer_name(self) -> str | ModuleName:
|
||||
if HAS_OPAQUE_TYPE:
|
||||
return ModuleName(self.layer_name)
|
||||
# Can be unavailable or None in unittests
|
||||
if (
|
||||
is_forward_context_available()
|
||||
@@ -624,53 +644,27 @@ class DefaultMoERunner(MoERunner):
|
||||
)
|
||||
|
||||
with sp_ctx:
|
||||
extra_tensors = None
|
||||
if do_naive_dispatch_combine:
|
||||
post_quant_allgather = (
|
||||
self.quant_method is not None
|
||||
and self.moe_config.dp_size > 1
|
||||
and self.moe_config.use_ep
|
||||
and getattr(self.quant_method, "do_post_quant_allgather", False)
|
||||
)
|
||||
if post_quant_allgather:
|
||||
hidden_states_to_dispatch, extra_tensors = (
|
||||
self.quant_method.prepare_dp_allgather_tensor(
|
||||
layer, hidden_states, router_logits
|
||||
)
|
||||
)
|
||||
else:
|
||||
hidden_states_to_dispatch = hidden_states
|
||||
|
||||
dispatch_res = get_ep_group().dispatch_router_logits(
|
||||
hidden_states_to_dispatch,
|
||||
router_logits,
|
||||
self.moe_config.is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
if extra_tensors is not None:
|
||||
(
|
||||
orig_hidden_states,
|
||||
router_logits,
|
||||
extra_tensors_combined,
|
||||
) = dispatch_res
|
||||
hidden_states_combined = (
|
||||
orig_hidden_states,
|
||||
extra_tensors_combined[0],
|
||||
)
|
||||
else:
|
||||
hidden_states_combined, router_logits = dispatch_res
|
||||
orig_hidden_states = hidden_states_combined
|
||||
else:
|
||||
orig_hidden_states = hidden_states
|
||||
|
||||
# Run shared experts before matrix multiply.
|
||||
# because the matrix multiply may modify hidden_states.
|
||||
if has_separate_shared_experts and not use_shared_experts_stream:
|
||||
if has_separate_shared_experts: # and not use_shared_experts_stream:
|
||||
assert self.shared_experts is not None
|
||||
shared_input = (
|
||||
shared_input if shared_input is not None else hidden_states
|
||||
)
|
||||
shared_output = self.shared_experts(shared_input)
|
||||
else:
|
||||
assert self.fused_shared_output == False, "fused_shared_output is only supported when has_separate_shared_experts is True"
|
||||
shared_output = None
|
||||
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
|
||||
# router logits to all experts.
|
||||
# NOTE: this will be removed once all kernels are migrated into the
|
||||
# MoEKernel framework.
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
self.moe_config.is_sequence_parallel,
|
||||
)
|
||||
|
||||
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||
@@ -685,42 +679,33 @@ class DefaultMoERunner(MoERunner):
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
|
||||
# Figure out nicer way to do this.
|
||||
if do_naive_dispatch_combine:
|
||||
x = hidden_states_combined
|
||||
x_orig = orig_hidden_states
|
||||
else:
|
||||
x = hidden_states
|
||||
x_orig = hidden_states
|
||||
|
||||
# Matrix multiply.
|
||||
if self.quant_method.is_monolithic:
|
||||
final_hidden_states = self.quant_method.apply_monolithic(
|
||||
layer=layer,
|
||||
x=x,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=x_orig,
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=x,  # The type signature of this is wrong due to the hack.
|
||||
x=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
shared_experts_input=shared_input,
|
||||
router_logits=router_logits,
|
||||
top_k=topk_ids.shape[-1]
|
||||
# Pass the shared-experts output through shared_experts_input so it can be fused into the final add.
|
||||
shared_experts_input=shared_output if self.fused_shared_output else None,
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
assert self.shared_experts is not None
|
||||
|
||||
if use_shared_experts_stream:
|
||||
assert use_shared_experts_stream == False, "Running shared experts in parallel with the main MoE execution is currently not supported!"
|
||||
# Run shared experts in parallel on a separate stream
|
||||
# NOTE: We start the separate stream here and mark the
|
||||
# sync end point immediately after it is done. This is
|
||||
@@ -733,7 +718,7 @@ class DefaultMoERunner(MoERunner):
|
||||
current_stream().wait_stream(self.shared_experts_stream)
|
||||
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
None if self.fused_shared_output else shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,14 +10,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
|
||||
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
Useful in the case when some FusedMoEPermuteExpertsUnpermute
|
||||
Useful in the case when some FusedMoEExpertsModular
|
||||
implementation does not perform weight application and reduction
|
||||
but cannot address the needs of all the compatible PrepareAndFinalize
|
||||
implementations.
|
||||
For example, BatchedTritonExperts is compatible with both
|
||||
PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize
|
||||
does the weight-application + reduction as part of the pplx combine kernel.
|
||||
But the BatchedPrepareAndFinalize needs an implementation. To facilitate
|
||||
For example, BatchedTritonExperts is compatible with both batched
|
||||
PrepareAndFinalize implementations like DeepEPLLPrepareAndFinalize and
|
||||
BatchedPrepareAndFinalize. Some PrepareAndFinalize implementations do
|
||||
the weight-application + reduction as part of the combine kernel, while
|
||||
BatchedPrepareAndFinalize needs an explicit implementation. To facilitate
|
||||
this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
|
||||
so the PrepareAndFinalize implementations could choose how to
|
||||
weight + reduce.
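# A rough sketch of what "weight application + reduction" means here, assuming
# fused_expert_output has shape [num_tokens, topk, hidden] and topk_weights has
# shape [num_tokens, topk] (illustrative shapes, not a vLLM API).
import torch

def apply_weights_and_reduce(fused_expert_output, topk_weights):
    # Scale each expert's contribution by its routing weight, then sum over
    # the top-k experts to produce one hidden state per token.
    return (fused_expert_output * topk_weights.unsqueeze(-1)).sum(dim=1)

out = apply_weights_and_reduce(torch.randn(4, 2, 16), torch.rand(4, 2))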
|
||||
@@ -61,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
|
||||
if output is None:
|
||||
return fused_expert_output
|
||||
|
||||
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
|
||||
# MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
|
||||
# tensor.
|
||||
assert output.size() == fused_expert_output.size(), (
|
||||
"output shape is expected to match the fused_expert_output shape. "
|
||||
|
||||
@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts):
|
||||
|
||||
@staticmethod
|
||||
def get_clses() -> tuple[
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
]:
|
||||
return (CutlassExpertsFp8, TritonExperts)
|
||||
|
||||
@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
# Small batch fallback for sm100.
|
||||
if self.is_sm100 and hidden_states.shape[0] <= 8:
|
||||
return self.fallback_experts
|
||||
|
||||
@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
|
||||
@staticmethod
|
||||
def get_clses() -> tuple[
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
]:
|
||||
return (DeepGemmExperts, TritonExperts)
|
||||
|
||||
@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
|
||||
return self.experts
|
||||
else:
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
|
||||
|
||||
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
|
||||
"""TensorRT-LLM-based fused MoE expert implementation."""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEExpertsModular,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
|
||||
UnquantizedMoeBackend,
|
||||
@@ -42,9 +42,9 @@ from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
if current_platform.is_cuda_alike() or current_platform.is_xpu():
|
||||
from .fused_batched_moe import BatchedTritonExperts
|
||||
from .fused_moe import TritonExperts
|
||||
else:
|
||||
TritonExperts = None # type: ignore
|
||||
fused_experts = None
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -70,7 +70,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self.rocm_aiter_moe_enabled = (
|
||||
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
|
||||
)
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
self.kernel: mk.FusedMoEKernel | None = None
|
||||
self._is_monolithic = (
|
||||
current_platform.is_cpu()
|
||||
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
|
||||
@@ -107,7 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
) -> FusedMoEPrepareAndFinalizeModular | None:
|
||||
if self.unquantized_backend == UnquantizedMoeBackend.AITER:
|
||||
return None
|
||||
else:
|
||||
@@ -115,9 +115,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
) -> FusedMoEExpertsModular:
|
||||
assert self.moe_quant_config is not None
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
@@ -296,16 +296,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# Pass the shared-experts output through shared_experts_input so it can be fused into the final add.
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
**kwargs
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.forward(
|
||||
result = self.forward(
|
||||
layer=layer,
|
||||
x=x,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
# not used
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
) * layer.routed_scaling_factor
|
||||
if shared_experts_input is not None:
|
||||
result += shared_experts_input
|
||||
return result
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
if self.moe.has_bias:
|
||||
@@ -333,10 +337,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
shared_experts_input=shared_experts_input,
|
||||
expert_map=layer.expert_map
|
||||
)
|
||||
|
||||
def forward_monolithic_cuda(
|
||||
|
||||
@@ -23,7 +23,7 @@ if current_platform.is_xpu():
|
||||
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
|
||||
|
||||
|
||||
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class XPUExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
@@ -82,11 +82,12 @@ def fused_add_rms_norm(
|
||||
return rms_norm_batch_invariant(
|
||||
x + residual, weight, variance_epsilon
|
||||
), x + residual
|
||||
ops.fused_add_rms_norm(
|
||||
x, residual = ops.fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
weight,
|
||||
variance_epsilon,
|
||||
residual_alpha,
|
||||
)
|
||||
return x, residual
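# Unfused reference for the call above (a sketch; residual_alpha is assumed
# here to scale the incoming residual before the add, which may differ from
# what the fused kernel actually does).
import torch

def fused_add_rms_norm_reference(x, residual, weight, eps, residual_alpha=1.0):
    residual = x + residual_alpha * residual  # updated residual returned alongside the normed output
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + eps)
    return (normed * weight.float()).to(x.dtype), residual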
|
||||
|
||||
@@ -125,7 +126,7 @@ def dispatch_rocm_rmsnorm_func(
|
||||
return fused_add_rms_norm
|
||||
return rms_norm
|
||||
|
||||
|
||||
|
||||
def rms_norm_qk(
|
||||
input_q: torch.Tensor,
|
||||
input_k: torch.Tensor,
|
||||
@@ -140,11 +141,7 @@ def rms_norm_qk(
|
||||
output_q, output_k, input_q, input_k, weight_q, weight_k, epsilon
|
||||
)
|
||||
return output_q, output_k
|
||||
|
||||
|
||||
def dispatch_cuda_rmsnorm_qk_func() -> callable:
|
||||
return rms_norm_qk
|
||||
|
||||
|
||||
|
||||
@CustomOp.register("rms_norm_qk")
|
||||
class RMSNormQK(CustomOp):
|
||||
@@ -226,8 +223,7 @@ class RMSNormQK(CustomOp):
|
||||
f"[RMSNormQK] Expected input_q and input_k to have same dtype, "
|
||||
f"but got {input_q.dtype} vs {input_k.dtype}"
|
||||
)
|
||||
norm_func = dispatch_cuda_rmsnorm_qk_func()
|
||||
return norm_func(
|
||||
return rms_norm_qk(
|
||||
input_q,
|
||||
input_k,
|
||||
weight_q,
|
||||
@@ -264,7 +260,7 @@ class RMSNormQK(CustomOp):
|
||||
f"eps={self.variance_epsilon}, "
|
||||
)
|
||||
|
||||
|
||||
# --8<-- [start:rms_norm]
|
||||
@CustomOp.register("rms_norm")
|
||||
class RMSNorm(CustomOp):
|
||||
"""Root mean square normalization.
|
||||
@@ -375,7 +371,7 @@ class RMSNorm(CustomOp):
|
||||
# otherwise Inductor eliminates the casts to and from f16,
|
||||
# increasing memory usage (and complicating pattern matching)
|
||||
x = x + residual
|
||||
residual = x.to(orig_dtype).contiguous()
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
if x.shape[-1] != hidden_size:
|
||||
raise ValueError(
|
||||
@@ -425,6 +421,7 @@ class RMSNorm(CustomOp):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
residual_alpha: float = 1.0,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
@@ -499,7 +496,7 @@ class RMSNorm(CustomOp):
|
||||
add_residual = residual is not None
|
||||
if add_residual:
|
||||
return fused_add_rms_norm(
|
||||
x, residual, self.weight.data, self.variance_epsilon
|
||||
x, residual, self.weight.data, self.variance_epsilon, residual_alpha
|
||||
)
|
||||
else:
|
||||
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||
@@ -649,6 +646,7 @@ class RMSNormGated(CustomOp):
|
||||
norm_before_gate: bool = False,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
activation: str = "swish",
|
||||
):
|
||||
"""Initialize RMSNormGated.
|
||||
|
||||
@@ -663,10 +661,12 @@ class RMSNormGated(CustomOp):
|
||||
If False and z is provided: out = norm(x * silu(z))
|
||||
device: Device to create parameters on
|
||||
dtype: Data type for parameters
|
||||
activation: Activation function name for gating
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.activation = activation
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.group_size = group_size
|
||||
@@ -693,6 +693,11 @@ class RMSNormGated(CustomOp):
|
||||
- norm_before_gate=True: out = norm(x) * silu(z)
|
||||
- norm_before_gate=False: out = norm(x * silu(z))
|
||||
"""
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
weight = self.weight.float()
|
||||
z = z.float() if z is not None else None
|
||||
|
||||
# Apply gating before normalization if needed
|
||||
if z is not None and not self.norm_before_gate:
|
||||
x = x * F.silu(z)
|
||||
@@ -702,7 +707,7 @@ class RMSNormGated(CustomOp):
|
||||
# Standard RMS norm across the last dimension
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x_normed = x * torch.rsqrt(variance + self.eps)
|
||||
out = x_normed * self.weight
|
||||
out = x_normed * weight
|
||||
else:
|
||||
# Group RMS norm
|
||||
from einops import rearrange
|
||||
@@ -710,13 +715,13 @@ class RMSNormGated(CustomOp):
|
||||
x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
|
||||
variance = x_group.pow(2).mean(dim=-1, keepdim=True)
|
||||
x_normed = x_group * torch.rsqrt(variance + self.eps)
|
||||
out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight
|
||||
out = rearrange(x_normed, "... g d -> ... (g d)") * weight
|
||||
|
||||
# Apply gating after normalization if needed
|
||||
if z is not None and self.norm_before_gate:
|
||||
out = out * F.silu(z)
|
||||
|
||||
return out
|
||||
return out.to(orig_dtype)
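# Tiny, self-contained illustration of the grouped path above: with
# group_size=4 and a hidden size of 8, each row is normalized in two
# independent groups of four elements (weight and gating omitted for brevity).
import torch
from einops import rearrange

x = torch.randn(2, 8)
g = rearrange(x, "... (g d) -> ... g d", d=4)
var = g.pow(2).mean(dim=-1, keepdim=True)
out = rearrange(g * torch.rsqrt(var + 1e-6), "... g d -> ... (g d)")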
|
||||
|
||||
def forward_cuda(
|
||||
self, x: torch.Tensor, z: torch.Tensor | None = None
|
||||
@@ -731,6 +736,7 @@ class RMSNormGated(CustomOp):
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
import ast, re
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
@@ -16,6 +16,7 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
@@ -28,7 +29,9 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
)
|
||||
from vllm.model_executor.layers.utils import (
|
||||
dispatch_unquantized_gemm,
|
||||
is_layer_moe_router_gate,
|
||||
parse_opt_exclude_layers,
|
||||
weight_quant_l1,
|
||||
weight_quant_l2,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
@@ -41,12 +44,11 @@ from vllm.model_executor.parameter import (
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
import vllm.envs as envs
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"UnquantizedLinearMethod",
|
||||
"CompressedTensorsLinearMethod",
|
||||
"CompressedTensorsLinearTransformMethod",
|
||||
"AWQMarlinLinearMethod",
|
||||
@@ -66,6 +68,14 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"PetitNvFp4LinearMethod",
|
||||
]
|
||||
|
||||
LINEAR_OPT_SUPPORTED = [
|
||||
"ColumnParallelLinear",
|
||||
"ReplicatedLinear",
|
||||
"RowParallelLinear",
|
||||
"QKVParallelLinear",
|
||||
"MergedColumnParallelLinear",
|
||||
]
|
||||
|
||||
|
||||
def adjust_marlin_shard(
|
||||
param: Parameter,
|
||||
@@ -135,44 +145,6 @@ def adjust_scalar_to_fused_array(
|
||||
return param_data[shard_id], loaded_weight
|
||||
|
||||
|
||||
# TODO(Isotr0py): We might need a more flexible structure to handle
|
||||
# bitsandbytes shard offsets.
|
||||
def left_shift_bitsandbytes_4bit_shard(
|
||||
bnb_weight_attrs: dict[str, Any],
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""
|
||||
Separate the BitsAndBytes 4-bit shard.
|
||||
|
||||
For example, given bnb weight attributes as below:
|
||||
{
|
||||
'bnb_shard_offsets': array([0, 4, 8, 16]),
|
||||
'bnb_quant_state': {0: ..., 1: ..., 2: ...},
|
||||
}
|
||||
|
||||
The function will return:
|
||||
{
|
||||
'bnb_shard_offsets': array([0, 4]),
|
||||
'bnb_quant_state': {0: ...},
|
||||
}
|
||||
and
|
||||
{
|
||||
'bnb_shard_offsets': array([0, 4, 12]),
|
||||
'bnb_quant_state': {0: ..., 1: ...},
|
||||
}
|
||||
"""
|
||||
shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
|
||||
offset_l = shard_offsets[:2]
|
||||
offset_r = shard_offsets[1:] - shard_offsets[1]
|
||||
quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
|
||||
quant_state_r = {
|
||||
i - 1: bnb_weight_attrs["bnb_quant_state"][i]
|
||||
for i in range(1, len(shard_offsets) - 1)
|
||||
}
|
||||
left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
|
||||
right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
|
||||
return left, right
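# Worked example of the split above using the numbers from the docstring
# (numpy stands in for the actual offset array type).
import numpy as np

shard_offsets = np.array([0, 4, 8, 16])
offset_l = shard_offsets[:2]                     # -> [0, 4]
offset_r = shard_offsets[1:] - shard_offsets[1]  # -> [0, 4, 12]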
|
||||
|
||||
|
||||
class LinearMethodBase(QuantizeMethodBase):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@@ -231,17 +203,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
# The weights are not quantized, and they are not sharded.
|
||||
# The amount of memory allocated for the weights is
|
||||
# sum(output_partition_sizes) * input_size_per_partition.
|
||||
weight_loader = extra_weight_attrs.pop("weight_loader")
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
@@ -258,11 +224,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
vllm_is_batch_invariant()
|
||||
and current_platform.is_cuda_alike()
|
||||
and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
|
||||
):
|
||||
if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
|
||||
return linear_batch_invariant(x, layer.weight, bias)
|
||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||
|
||||
@@ -305,15 +267,31 @@ class LinearBase(PluggableLayer):
|
||||
self.quant_config = quant_config
|
||||
self.prefix = prefix
|
||||
self.allow_fp8_block_shape_mismatch = False
|
||||
if quant_config is None:
|
||||
self.opt_level = envs.VLLM_LINEAR_OPT_LEVEL
|
||||
if parse_opt_exclude_layers(envs.VLLM_LINEAR_SPECIFIED_LAYERS, self.prefix) or \
|
||||
(envs.VLLM_LINEAR_SPECIFIED_KEYS != "" and envs.VLLM_LINEAR_SPECIFIED_KEYS in self.prefix):
|
||||
self.opt_level = envs.VLLM_LINEAR_SPECIFIED_OPT_LEVEL
|
||||
self.opt_flag = quant_config is None and self.opt_level != 0 and \
|
||||
self.__class__.__name__ in LINEAR_OPT_SUPPORTED
|
||||
|
||||
if parse_opt_exclude_layers(envs.VLLM_OPT_EXCLUDE_LAYERS, self.prefix):
|
||||
self.opt_flag = False
|
||||
logger.info(f"Excluding layer {self.prefix} from optimization")
|
||||
|
||||
if self.opt_flag:
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsLinearMethod
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import CompressedTensorsW8A8Int8
|
||||
self.quant_method: QuantizeMethodBase | None = CompressedTensorsLinearMethod(None)
|
||||
self.scheme = CompressedTensorsW8A8Int8(QuantizationStrategy.CHANNEL, False, True, is_w4a8_linear=True if self.opt_level == 2 else False)
|
||||
elif quant_config is None:
|
||||
self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
self.output_padding_size = 0
|
||||
self.disable_tp = disable_tp
|
||||
self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
|
||||
self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
|
||||
self.output_padding_size = 0
|
||||
|
||||
def update_param_tp_status(self):
|
||||
for param in self.parameters():
|
||||
@@ -402,7 +380,7 @@ class ReplicatedLinear(LinearBase):
|
||||
# If the weight on disk does not have a shape, give it one
|
||||
# (such scales for AutoFp8).
|
||||
# Special case for GGUF
|
||||
|
||||
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if is_gguf_weight_type:
|
||||
@@ -419,7 +397,17 @@ class ReplicatedLinear(LinearBase):
|
||||
f"Tried to load weights of size {loaded_weight.size()}"
|
||||
f"to a parameter of size {param.size()}"
|
||||
)
|
||||
if self.opt_flag:
|
||||
if self.opt_level == 1:
|
||||
loaded_weight, scale = weight_quant_l1(loaded_weight)
|
||||
elif self.opt_level == 2:
|
||||
loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
|
||||
|
||||
param.data.copy_(loaded_weight)
|
||||
if self.opt_flag:
|
||||
params_dict = dict(self.named_parameters())
|
||||
scale_param = params_dict["weight_scale"]
|
||||
scale_param.data.copy_(scale)
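# weight_quant_l1 / weight_quant_l2 are not shown in this diff; as an
# assumption, opt_level 1 is sketched here as plain per-output-channel
# symmetric int8 quantization that returns the quantized weight plus a scale
# per output channel (illustrative only, not the actual helpers).
import torch

def per_channel_int8_quant(w: torch.Tensor):
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale.squeeze(1)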
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -609,7 +597,18 @@ class ColumnParallelLinear(LinearBase):
|
||||
if len(loaded_weight.shape) == 0:
|
||||
assert loaded_weight.numel() == 1
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
if self.opt_flag:
|
||||
if self.opt_level == 1:
|
||||
loaded_weight, scale = weight_quant_l1(loaded_weight)
|
||||
elif self.opt_level == 2:
|
||||
loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
|
||||
|
||||
param.load_column_parallel_weight(loaded_weight=loaded_weight)
|
||||
if self.opt_flag:
|
||||
params_dict = dict(self.named_parameters())
|
||||
scale_param = params_dict["weight_scale"]
|
||||
scale_param.load_column_parallel_weight(loaded_weight=scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -733,16 +732,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_shard_id: tuple[int, ...] | int | None = None,
|
||||
):
|
||||
self.validate_shard_id(loaded_shard_id)
|
||||
# FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
|
||||
if isinstance(loaded_shard_id, tuple):
|
||||
raise NotImplementedError(
|
||||
"Shard id with multiple indices is not supported in weight_loader, "
|
||||
"please use weight_loader_v2 instead."
|
||||
)
|
||||
# Special case for GGUF
|
||||
# initialize GGUF param after we know the quantize type
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if isinstance(loaded_shard_id, tuple) and (
|
||||
is_gguf_weight or is_gguf_weight_type
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Shard id with multiple indices is not supported for GGUF."
|
||||
)
|
||||
if is_gguf_weight_type:
|
||||
if loaded_shard_id is not None:
|
||||
param.data[loaded_shard_id].copy_(loaded_weight)
|
||||
@@ -770,7 +769,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
# Special case for per-tensor scale to load scalar into fused array.
|
||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
|
||||
# Loaded weight is already fused on disk (mlp).
|
||||
# (e.g., Phi-3's gate_up_proj).
|
||||
if output_dim is None:
|
||||
@@ -782,10 +781,25 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
|
||||
output_sizes = (
|
||||
self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
|
||||
if loaded_shard_id is not None
|
||||
else self.output_sizes
|
||||
)
|
||||
current_shard_offset = 0
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
if (
|
||||
use_bitsandbytes_4bit
|
||||
and isinstance(loaded_shard_id, tuple)
|
||||
and self.tp_size > 1
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Shard id with multiple indices is not supported "
|
||||
"for BNB quantization with TP yet."
|
||||
)
|
||||
shard_offsets: list[tuple[int, int, int]] = []
|
||||
for i, output_size in enumerate(self.output_sizes):
|
||||
for i, output_size in enumerate(output_sizes):
|
||||
shard_offsets.append((i, current_shard_offset, output_size))
|
||||
current_shard_offset += output_size
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
@@ -850,9 +864,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||
|
||||
if use_bitsandbytes_4bit:
|
||||
shard_size = loaded_weight.shape[output_dim]
|
||||
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
|
||||
|
||||
index = list(itertools.accumulate([0] + self.output_sizes))
|
||||
orig_offsets = {
|
||||
str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
|
||||
}
|
||||
orig_offsets["total"] = (self.output_size, 0)
|
||||
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||
param, orig_offsets, str(loaded_shard_id)
|
||||
)
|
||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
if not is_sharded_weight:
|
||||
@@ -921,12 +940,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: tuple[int, ...] | int | None = None,
|
||||
):
|
||||
if self.opt_flag:
|
||||
if self.opt_level == 1:
|
||||
loaded_weight, scale = weight_quant_l1(loaded_weight)
|
||||
elif self.opt_level == 2:
|
||||
loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
|
||||
self.validate_shard_id(loaded_shard_id)
|
||||
dtype = loaded_weight.dtype
|
||||
if envs.VLLM_W8A8_LINEAR_USE_W4A8 and not (param.shape[0] == 1 or param.shape[1] == 1) and dtype == torch.int8:
|
||||
load_sizes = [self.output_sizes[i] // 2 for i in range(len(self.output_sizes))]
|
||||
else:
|
||||
load_sizes = self.output_sizes
|
||||
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
|
||||
if isinstance(param, PerTensorScaleParameter):
|
||||
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
|
||||
@@ -953,19 +972,21 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
|
||||
# shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
||||
# shard_size = self.output_sizes[loaded_shard_id]
|
||||
shard_offset = sum(load_sizes[:loaded_shard_id])
|
||||
shard_size = load_sizes[loaded_shard_id]
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
||||
shard_size = self.output_sizes[loaded_shard_id]
|
||||
shard_offset //= self.tp_size
|
||||
shard_size //= self.tp_size
|
||||
scale_shard_offset = shard_offset
|
||||
scale_shard_size = shard_size
|
||||
if self.opt_flag and self.opt_level == 2:
|
||||
shard_offset = shard_offset // 2
|
||||
shard_size = shard_size // 2
|
||||
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
weight_block_size = getattr(self, "weight_block_size", None)
|
||||
shard_size, shard_offset = adjust_block_scale_shard(
|
||||
weight_block_size, shard_size, shard_offset
|
||||
)
|
||||
|
||||
param.load_merged_column_weight(
|
||||
loaded_weight=loaded_weight,
|
||||
shard_id=loaded_shard_id,
|
||||
@@ -973,6 +994,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
shard_size=shard_size,
|
||||
tp_rank=self.tp_rank,
|
||||
)
|
||||
if self.opt_flag:
|
||||
params_dict = dict(self.named_parameters())
|
||||
scale_param = params_dict["weight_scale"]
|
||||
scale_param.load_merged_column_weight(
|
||||
loaded_weight=scale,
|
||||
shard_id=loaded_shard_id,
|
||||
shard_offset=scale_shard_offset,
|
||||
shard_size=scale_shard_size,
|
||||
tp_rank=self.tp_rank,
|
||||
)
|
||||
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
@@ -1128,12 +1159,24 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
|
||||
shard_size=shard_size, shard_offset=shard_offset
|
||||
)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
param.output_dim, shard_offset, shard_size
|
||||
)
|
||||
if self.opt_level == 2:
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
0, shard_offset, shard_size
|
||||
)
|
||||
else:
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
param.output_dim, shard_offset, shard_size
|
||||
)
|
||||
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
|
||||
|
||||
def quant(self, loaded_weight: torch.Tensor):
|
||||
if self.opt_flag:
|
||||
if self.opt_level == 1:
|
||||
return weight_quant_l1(loaded_weight)
|
||||
elif self.opt_level == 2:
|
||||
return weight_quant_l2(loaded_weight, format="NN")
|
||||
return loaded_weight, None
|
||||
|
||||
def weight_loader_v2(
|
||||
self,
|
||||
param: BasevLLMParameter,
|
||||
@@ -1141,15 +1184,27 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_shard_id: str | None = None,
|
||||
):
|
||||
self.validate_shard_id(loaded_shard_id)
|
||||
params_dict = dict(self.named_parameters())
|
||||
if loaded_shard_id is None: # special case for certain models
|
||||
if isinstance(param, PerTensorScaleParameter):
|
||||
loaded_weight, scale = self.quant(loaded_weight)
|
||||
param.load_qkv_weight(
|
||||
loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
|
||||
)
|
||||
if self.opt_flag:
|
||||
scale_param = params_dict["weight_scale"]
|
||||
scale_param.load_qkv_weight(
|
||||
loaded_weight=scale, shard_id=0, tp_rank=self.tp_rank
|
||||
)
|
||||
return
|
||||
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
|
||||
loaded_weight, scale = self.quant(loaded_weight)
|
||||
param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
|
||||
if self.opt_flag:
|
||||
scale_param = params_dict["weight_scale"]
|
||||
scale_param.load_qkv_weight(loaded_weight=scale, tp_rank=self.tp_rank)
|
||||
return
|
||||
|
||||
# TODO: @dsikka - move to parameter.py
|
||||
self._load_fused_module_from_checkpoint(param, loaded_weight)
|
||||
return
|
||||
@@ -1158,11 +1213,15 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
|
||||
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
|
||||
shard_size = self._get_shard_size_mapping(loaded_shard_id)
|
||||
dtype = loaded_weight.dtype
|
||||
# The w4a8 GEMM needs the shard offset/size divided by 2; the scale does not.
|
||||
if envs.VLLM_W8A8_LINEAR_USE_W4A8 and not (param.shape[0] == 1 or param.shape[1] == 1) and dtype == torch.int8:
|
||||
shard_offset //= 2
|
||||
shard_size //= 2
|
||||
|
||||
scale_shard_offset = shard_offset
|
||||
scale_shard_size = shard_size
|
||||
|
||||
loaded_weight, scale = self.quant(loaded_weight)
|
||||
|
||||
if self.opt_flag and self.opt_level == 2:
|
||||
shard_offset = shard_offset // 2
|
||||
shard_size = shard_size // 2
|
||||
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
weight_block_size = getattr(self, "weight_block_size", None)
|
||||
@@ -1179,6 +1238,15 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
tp_rank=self.tp_rank,
|
||||
)
|
||||
|
||||
if self.opt_flag:
|
||||
scale_param = params_dict["weight_scale"]
|
||||
scale_param.load_qkv_weight(loaded_weight=scale,
|
||||
num_heads=self.num_kv_head_replicas,
|
||||
shard_id=loaded_shard_id,
|
||||
shard_offset=scale_shard_offset,
|
||||
shard_size=scale_shard_size,
|
||||
tp_rank=self.tp_rank)
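# Why the w4a8 path halves shard_offset/shard_size above (sketch): two int4
# values are assumed to be packed into one int8 byte along the sharded
# dimension, so the packed weight is half as long while the per-channel scales
# keep the original length. The packing layout below is illustrative only.
import torch

def pack_int4_pairs(q: torch.Tensor) -> torch.Tensor:
    q = q.to(torch.uint8) & 0x0F                     # keep the low nibble of each value
    return (q[..., 0::2] | (q[..., 1::2] << 4)).to(torch.int8)  # pack adjacent pairs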
|
||||
|
||||
def weight_loader(
|
||||
self,
|
||||
param: Parameter,
|
||||
@@ -1525,7 +1593,17 @@ class RowParallelLinear(LinearBase):
|
||||
assert loaded_weight.numel() == 1
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
if self.opt_flag:
|
||||
if self.opt_level == 1:
|
||||
loaded_weight, scale = weight_quant_l1(loaded_weight)
|
||||
elif self.opt_level == 2:
|
||||
loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
|
||||
|
||||
param.load_row_parallel_weight(loaded_weight=loaded_weight)
|
||||
if self.opt_flag:
|
||||
params_dict = dict(self.named_parameters())
|
||||
scale_param = params_dict["weight_scale"]
|
||||
scale_param.load_row_parallel_weight(loaded_weight=scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -61,7 +61,10 @@ class LogitsProcessor(CustomOp):
|
||||
logits = hidden_states
|
||||
else:
|
||||
# Get the logits for the next tokens.
|
||||
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
|
||||
if hidden_states.shape[0] > 0:
|
||||
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
|
||||
else:
|
||||
logits = torch.empty([0, lm_head.weight.shape[0]], device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
if logits is not None:
|
||||
if self.soft_cap is not None:
|
||||
logits = logits / self.soft_cap
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
self.variance_epsilon = eps
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def weight_loader(
|
||||
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
shard_size = loaded_weight.shape[0] // tp_world
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
param.data.copy_(loaded_weight[shard])
|
||||
return
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
return q, k
|
||||
|
||||
|
||||
def clear_linear_attention_cache_for_new_sequences(
|
||||
kv_cache: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
attn_metadata: LinearAttentionMetadata,
|
||||
) -> None:
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
if num_prefills <= 0:
|
||||
return
|
||||
|
||||
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
|
||||
for prefill_idx in range(num_prefills):
|
||||
q_start = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx]
|
||||
q_end = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx + 1]
|
||||
query_len = q_end - q_start
|
||||
context_len = (
|
||||
attn_metadata.seq_lens[num_decode_tokens + prefill_idx] - query_len
|
||||
)
|
||||
if context_len == 0:
|
||||
block_to_clear = state_indices_tensor[num_decode_tokens + prefill_idx]
|
||||
kv_cache[block_to_clear, ...] = 0
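# Small numeric illustration of the check above (made-up metadata values):
# a prefill whose context length is zero is a brand-new sequence, so its
# cache slot is cleared; a prefill with prior context keeps its state.
query_start_loc = [0, 3, 7]  # two prefills: query lengths 3 and 4
seq_lens = [3, 10]
new_sequence = []
for i in range(2):
    query_len = query_start_loc[i + 1] - query_start_loc[i]
    context_len = seq_lens[i] - query_len
    new_sequence.append(context_len == 0)
# new_sequence == [True, False]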
|
||||
|
||||
|
||||
def linear_attention_decode(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slope_rate: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
q_start: int = 0,
|
||||
q_end: int | None = None,
|
||||
slot_start: int = 0,
|
||||
slot_end: int | None = None,
|
||||
block_size: int = 32,
|
||||
) -> torch.Tensor:
|
||||
q = q[q_start:q_end].unsqueeze(2).contiguous()
|
||||
k = k[q_start:q_end].unsqueeze(2).contiguous()
|
||||
v = v[q_start:q_end].unsqueeze(2).contiguous()
|
||||
slot_id = state_indices_tensor[slot_start:slot_end]
|
||||
return linear_decode_forward_triton(
|
||||
q, k, v, kv_cache, slope_rate, slot_id, block_size
|
||||
)
|
||||
|
||||
|
||||
def linear_attention_prefill_and_mix(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
attn_metadata: LinearAttentionMetadata,
|
||||
slope_rate: torch.Tensor,
|
||||
block_size: int,
|
||||
decode_fn: Callable[..., torch.Tensor],
|
||||
prefix_fn: Callable[..., torch.Tensor],
|
||||
layer_idx: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden = []
|
||||
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
|
||||
if _prefill_idx >= len(attn_metadata.query_start_loc):
|
||||
break
|
||||
if _prefill_idx >= len(state_indices_tensor):
|
||||
break
|
||||
offset = attn_metadata.num_decode_tokens
|
||||
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
||||
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
||||
slot_id = state_indices_tensor[offset + _prefill_idx]
|
||||
qs = q[_start:_end].transpose(0, 1).contiguous()
|
||||
ks = k[_start:_end].transpose(0, 1).contiguous()
|
||||
vs = v[_start:_end].transpose(0, 1).contiguous()
|
||||
slice_layer_cache = kv_cache[slot_id, ...]
|
||||
out_slice = prefix_fn(
|
||||
qs,
|
||||
ks,
|
||||
vs,
|
||||
slice_layer_cache,
|
||||
slope_rate,
|
||||
block_size,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
hidden.append(out_slice.contiguous())
|
||||
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
hidden_decode = decode_fn(
|
||||
q, k, v, kv_cache, state_indices_tensor, attn_metadata
|
||||
)
|
||||
hidden.insert(0, hidden_decode)
|
||||
|
||||
if not hidden:
|
||||
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
||||
|
||||
hidden = torch.concat(hidden, dim=0).contiguous()
|
||||
return hidden
|
||||
|
||||
|
||||
class MiniMaxText01LinearKernel:
|
||||
@staticmethod
|
||||
def jit_linear_forward_prefix(
|
||||
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
def _prefill_and_mix_infer(
|
||||
self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
|
||||
):
|
||||
hidden = []
|
||||
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
|
||||
if _prefill_idx >= len(attn_metadata.query_start_loc):
|
||||
break
|
||||
if _prefill_idx >= len(state_indices_tensor):
|
||||
break
|
||||
offset = attn_metadata.num_decode_tokens
|
||||
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
||||
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
||||
slot_id = state_indices_tensor[offset + _prefill_idx]
|
||||
qs = q[_start:_end].transpose(0, 1).contiguous()
|
||||
ks = k[_start:_end].transpose(0, 1).contiguous()
|
||||
vs = v[_start:_end].transpose(0, 1).contiguous()
|
||||
slice_layer_cache = kv_cache[slot_id, ...]
|
||||
|
||||
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
|
||||
qs,
|
||||
ks,
|
||||
vs,
|
||||
slice_layer_cache,
|
||||
self.tp_slope,
|
||||
self.BLOCK,
|
||||
layer_idx=self.layer_idx,
|
||||
)
|
||||
hidden.append(out_slice.contiguous())
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
hidden_decode = self._decode_infer(
|
||||
q, k, v, kv_cache, state_indices_tensor, attn_metadata
|
||||
)
|
||||
hidden.insert(0, hidden_decode)
|
||||
|
||||
if not hidden:
|
||||
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
||||
|
||||
hidden = torch.concat(hidden, dim=0).contiguous()
|
||||
return hidden
|
||||
return linear_attention_prefill_and_mix(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
kv_cache=kv_cache,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
attn_metadata=attn_metadata,
|
||||
slope_rate=self.tp_slope,
|
||||
block_size=self.BLOCK,
|
||||
decode_fn=self._decode_infer,
|
||||
prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix,
|
||||
layer_idx=self.layer_idx,
|
||||
)
|
||||
|
||||
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
|
||||
q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
slot_id = state_indices_tensor[: attn_metadata.num_decodes]
|
||||
hidden = linear_decode_forward_triton(
|
||||
q, k, v, kv_cache, self.tp_slope, slot_id, 32
|
||||
hidden = linear_attention_decode(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
self.tp_slope,
|
||||
state_indices_tensor,
|
||||
q_start=0,
|
||||
q_end=attn_metadata.num_decode_tokens,
|
||||
slot_start=0,
|
||||
slot_end=attn_metadata.num_decodes,
|
||||
block_size=32,
|
||||
)
|
||||
return hidden
|
||||
|
||||
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
if attn_metadata is not None:
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
if num_prefills > 0:
|
||||
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
|
||||
for prefill_idx in range(num_prefills):
|
||||
q_start = attn_metadata.query_start_loc[
|
||||
num_decode_tokens + prefill_idx
|
||||
]
|
||||
q_end = attn_metadata.query_start_loc[
|
||||
num_decode_tokens + prefill_idx + 1
|
||||
]
|
||||
query_len = q_end - q_start
|
||||
context_len = (
|
||||
attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
|
||||
- query_len
|
||||
)
|
||||
if context_len == 0:
|
||||
block_to_clear = state_indices_tensor[
|
||||
num_decode_tokens + prefill_idx
|
||||
]
|
||||
kv_cache[block_to_clear, ...] = 0
|
||||
clear_linear_attention_cache_for_new_sequences(
|
||||
kv_cache, state_indices_tensor, attn_metadata
|
||||
)
|
||||
|
||||
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
|
||||
if attn_metadata is None:
|
||||
|
||||
@@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer):
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
|
||||
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
@@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer):
|
||||
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
|
||||
initial_state_idx=block_idx_last_computed_token_p,
|
||||
cu_chunk_seqlen=cu_chunk_seqlen_p,
|
||||
last_chunk_indices=last_chunk_indices_p,
|
||||
)
|
||||
ssm_outputs.append(scan_out_p)
|
||||
|
||||
|
||||
@@ -289,9 +289,6 @@ def get_temporal_copy_spec(
|
||||
)
|
||||
|
||||
|
||||
get_full_copy_spec = get_temporal_copy_spec
|
||||
|
||||
|
||||
class MambaStateCopyFuncCalculator:
|
||||
@classmethod
|
||||
def linear_attention_state_copy_func(cls):
|
||||
|
||||
@@ -1159,7 +1159,7 @@ def causal_conv1d_update(
|
||||
f"ERROR: conv_state_indices should have shape ({batch},*) but got {conv_state_indices.shape}"
|
||||
)
|
||||
|
||||
# assert num_cache_lines >= batch
|
||||
assert num_cache_lines >= batch
|
||||
assert weight.stride(1) == 1 # Need this
|
||||
|
||||
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
|
||||
|
||||
@@ -497,6 +497,8 @@ def selective_scan_fn(
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
initial_state_idx=None,
|
||||
cu_chunk_seqlen=None,
|
||||
last_chunk_indices=None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
@@ -588,6 +590,8 @@ def selective_scan_fn(
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
initial_state_idx,
|
||||
cu_chunk_seqlen,
|
||||
last_chunk_indices,
|
||||
)
|
||||
|
||||
if z is None:
|
||||
|
||||
@@ -9,7 +9,6 @@ from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.attention import MLAAttention
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLAModules:
|
||||
"""Modules used in MLA."""
|
||||
@@ -18,7 +17,7 @@ class MLAModules:
|
||||
kv_b_proj: torch.nn.Module
|
||||
rotary_emb: torch.nn.Module
|
||||
o_proj: torch.nn.Module
|
||||
fused_qkv_a_proj: torch.nn.Module | None
|
||||
q_a_proj: torch.nn.Module | None
|
||||
kv_a_proj_with_mqa: torch.nn.Module | None
|
||||
q_a_layernorm: torch.nn.Module | None
|
||||
q_b_proj: torch.nn.Module | None
|
||||
@@ -74,7 +73,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.num_heads = num_heads
|
||||
self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
|
||||
self.q_a_proj = mla_modules.q_a_proj
|
||||
self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
|
||||
self.q_a_layernorm = mla_modules.q_a_layernorm
|
||||
self.q_b_proj = mla_modules.q_b_proj
|
||||
@@ -106,7 +105,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
use_sparse=self.is_sparse,
|
||||
indexer=self.indexer,
|
||||
rotary_emb=self.rotary_emb
|
||||
rotary_emb=self.rotary_emb,
|
||||
)
|
||||
|
||||
self.prefix = prefix
|
||||
@@ -119,60 +118,47 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
) -> torch.Tensor:
|
||||
q_c = None
|
||||
kv_lora = None
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
assert self.fused_qkv_a_proj is not None, (
|
||||
"fused_qkv_a_proj is required when q_lora_rank is not None"
|
||||
)
|
||||
assert self.q_a_layernorm is not None, (
|
||||
"q_a_layernorm is required when q_lora_rank is not None"
|
||||
)
|
||||
assert self.q_b_proj is not None, (
|
||||
"q_b_proj is required when q_lora_rank is not None"
|
||||
)
|
||||
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
|
||||
q_c, kv_lora = qkv_lora.split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
dim=-1,
|
||||
)
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
q = self.q_b_proj(q_c)[0]
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_heads, self.qk_head_dim)
|
||||
kv_a = self.kv_a_layernorm(kv_a)
|
||||
else:
|
||||
assert self.kv_a_proj_with_mqa is not None, (
|
||||
"kv_a_proj_with_mqa is required when q_lora_rank is None"
|
||||
)
|
||||
assert self.q_proj is not None, (
|
||||
"q_proj is required when q_lora_rank is None"
|
||||
)
|
||||
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
q = self.q_proj(hidden_states)[0]
|
||||
|
||||
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c)
|
||||
|
||||
q = q.view(-1, self.num_heads, self.qk_head_dim)
|
||||
# Add head dim of 1 to k_pe
|
||||
# k_pe = k_pe.unsqueeze(1)
|
||||
|
||||
# if self.rotary_emb is not None:
|
||||
# q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
|
||||
# positions, q[..., self.qk_nope_head_dim :], k_pe
|
||||
# )
|
||||
|
||||
if self.indexer and self.is_sparse:
|
||||
_topk_indices = self.indexer(
|
||||
hidden_states, q_c, positions, self.indexer_rope_emb
|
||||
)
|
||||
|
||||
q = self.q_proj(hidden_states)[0].view(-1, self.num_heads, self.qk_head_dim)
|
||||
latent_kpe = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
kv_a, k_pe = latent_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
|
||||
kv_a = self.kv_a_layernorm(kv_a)
|
||||
|
||||
# NOTE: the attention inputs do not carry positions, so pass them explicitly here
|
||||
if llama_4_scaling is not None:
|
||||
q *= llama_4_scaling
|
||||
|
||||
self.mla_attn.impl.forward_prepare(positions)
|
||||
attn_out = self.mla_attn(
|
||||
q,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
|
||||
)
|
||||
|
||||
attn_out = self.mla_attn(q, kv_a, k_pe, positions)
|
||||
return self.o_proj(attn_out)[0]
|
||||
|
||||
def forward_opt(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
llama_4_scaling: torch.Tensor | None = None):
|
||||
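# Optimized path: q_a_proj is assumed here to be a fused projection that emits
# q_lora, kv_lora and k_pe (plus output padding) in a single GEMM and is split below.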
if self.q_lora_rank is not None:
|
||||
q_latent_kpe = self.q_a_proj(hidden_states)[0]
|
||||
q, kv_a, k_pe, _ = q_latent_kpe.split([self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim, self.q_a_proj.output_padding_size], dim=1)
|
||||
q_c = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q_c)[0].view(-1, self.num_heads, self.qk_head_dim)
|
||||
kv_a = self.kv_a_layernorm(kv_a)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(-1, self.num_heads, self.qk_head_dim)
|
||||
latent_kpe = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
kv_a, k_pe = latent_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
|
||||
kv_a = self.kv_a_layernorm(kv_a)
|
||||
if self.indexer and self.is_sparse:
|
||||
_topk_indices = self.indexer(hidden_states, q_c, positions,
|
||||
self.rotary_emb)
|
||||
|
||||
# NOTE: the attention inputs do not carry positions, so pass them explicitly here
|
||||
if llama_4_scaling is not None:
|
||||
q *= llama_4_scaling
|
||||
attn_out = self.mla_attn(q, kv_a, k_pe, positions)
|
||||
return self.o_proj(attn_out)[0]
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ QuantizationMethods = Literal[
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"modelopt_mxfp8",
|
||||
"modelopt_mixed",
|
||||
"gguf",
|
||||
"gptq_marlin",
|
||||
"awq_marlin",
|
||||
@@ -32,6 +33,7 @@ QuantizationMethods = Literal[
|
||||
"mxfp4",
|
||||
"petit_nvfp4",
|
||||
"cpu_awq",
|
||||
"w8a16"
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
@@ -120,12 +122,18 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .gptq import GPTQConfig
|
||||
from .gptq_marlin import GPTQMarlinConfig
|
||||
from .inc import INCConfig
|
||||
from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config
|
||||
from .modelopt import (
|
||||
ModelOptFp8Config,
|
||||
ModelOptMixedPrecisionConfig,
|
||||
ModelOptMxFp8Config,
|
||||
ModelOptNvFp4Config,
|
||||
)
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .mxfp4 import Mxfp4Config
|
||||
from .petit import PetitNvFp4Config
|
||||
from .ptpc_fp8 import PTPCFp8Config
|
||||
from .torchao import TorchAOConfig
|
||||
from .w8a16 import W8a16Config
|
||||
|
||||
method_to_config: dict[str, type[QuantizationConfig]] = {
|
||||
"awq": AWQConfig,
|
||||
@@ -135,6 +143,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"modelopt": ModelOptFp8Config,
|
||||
"modelopt_fp4": ModelOptNvFp4Config,
|
||||
"modelopt_mxfp8": ModelOptMxFp8Config,
|
||||
"modelopt_mixed": ModelOptMixedPrecisionConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
@@ -151,6 +160,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"mxfp4": Mxfp4Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
"cpu_awq": CPUAWQConfig,
|
||||
"w8a16": W8a16Config,
|
||||
}
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
@@ -9,6 +9,7 @@ from torch.nn import Parameter
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
@@ -60,6 +61,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -197,7 +199,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return quant_method
|
||||
elif isinstance(layer, FusedMoE):
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
# from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
|
||||
# if is_layer_skipped(
|
||||
# prefix,
|
||||
@@ -213,9 +215,10 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
|
||||
# layer, prefix
|
||||
# )
|
||||
moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
|
||||
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return moe_quant_method
|
||||
# moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
|
||||
# moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
# return moe_quant_method
|
||||
return AWQMarlinMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -389,13 +392,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
replace_parameter(layer, "qweight", pad_qweight)
|
||||
replace_parameter(layer, "qzeros", pad_qzeros)
|
||||
replace_parameter(layer, "scales", pad_scales)
|
||||
return
|
||||
|
||||
# TODO(gyf): Marlin format is not supported for now.
|
||||
device = layer.qweight.device
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
return
|
||||
# Allocate marlin workspace
|
||||
layer.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
@@ -811,49 +814,33 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# shared_experts_output is passed in as shared_experts_input so it can be fused into the final reduce
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# return fused_marlin_moe(
|
||||
# x,
|
||||
# layer.w13_qweight,
|
||||
# layer.w2_qweight,
|
||||
# getattr(layer, "w13_bias", None),
|
||||
# getattr(layer, "w2_bias", None),
|
||||
# layer.w13_scales,
|
||||
# layer.w2_scales,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
# input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
# quant_type_id=self.quant_type.id,
|
||||
# apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
# global_num_experts=layer.global_num_experts,
|
||||
# expert_map=layer.expert_map,
|
||||
# w1_zeros=layer.w13_qzeros,
|
||||
# w2_zeros=layer.w2_qzeros,
|
||||
# workspace=layer.workspace,
|
||||
# input_dtype=self.input_dtype,
|
||||
# inplace=not self.moe.disable_inplace,
|
||||
# )
|
||||
|
||||
num_tokens, num_experts = router_logits.shape
|
||||
assert layer.activation.value == "silu", "Only SiLU activation is supported."
|
||||
use_ep = layer.expert_map is not None
|
||||
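# Decode-only batches (no prefills) without expert parallelism take the grouped
# GEMV fast path below; everything else falls back to the grouped GEMM path.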
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata:
|
||||
if isinstance(attn_metadata, dict):
|
||||
only_decode = (not use_ep and all(t.num_decodes > 0 and t.num_prefills == 0 for t in attn_metadata.values()))
|
||||
else:
|
||||
only_decode = (not use_ep and attn_metadata.num_decodes > 0 and attn_metadata.num_prefills == 0)
|
||||
else:
|
||||
only_decode = False
|
||||
|
||||
if use_ep:
|
||||
start_eid = layer.ep_rank * layer.local_num_experts
|
||||
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
|
||||
if layer.apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
num_tokens = topk_ids.shape[0]
|
||||
num_experts = layer.global_num_experts
|
||||
if use_ep:
|
||||
hidden_size = x.shape[1]
|
||||
(
|
||||
@@ -875,7 +862,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
dtype=x.dtype,
|
||||
)
|
||||
else:
|
||||
expand_tokens = num_tokens * top_k
|
||||
expand_tokens = num_tokens * layer.top_k
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
@@ -885,7 +872,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
)
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
|
||||
# expand + reorder
|
||||
# TODO use kernel
|
||||
@@ -893,76 +879,130 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
hidden_states=x,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_tokens=expand_tokens,
|
||||
topk=top_k,
|
||||
topk=layer.top_k,
|
||||
src_to_dst=src_to_dst,
|
||||
)
|
||||
|
||||
# w4a16 group gemm 1
|
||||
# pt_output_1: (expand_tokens, 2n) dtype
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemm(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
# w4a16 group gemm 2 + reorder
|
||||
# pt_output_3: (expand_tokens, k) dtype
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * top_k, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
if only_decode:
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemv(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
output=pt_output_3,
|
||||
)
|
||||
|
||||
reduce_mask = src_to_dst == -1
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=routed_scaling_factor,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemv(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=routed_scaling_factor
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
extra_residual=shared_experts_input,
|
||||
)
|
||||
|
||||
else:
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemm(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
# w4a16 group gemm 2 + reorder
|
||||
# pt_output_3: (expand_tokens, k) dtype
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * layer.top_k, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
output=pt_output_3,
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
reduce_mask = src_to_dst == -1
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
extra_residual=shared_experts_input,
|
||||
)
|
||||
return final_hidden_states
|
||||
# return torch.ops.vllm.fused_marlin_moe(
|
||||
# x,
|
||||
# layer.w13_qweight,
|
||||
# layer.w2_qweight,
|
||||
# layer.w13_scales,
|
||||
# layer.w2_scales,
|
||||
# router_logits,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# w1_zeros=layer.w13_qzeros,
|
||||
# w2_zeros=layer.w2_qzeros,
|
||||
# num_bits=self.quant_config.weight_bits,
|
||||
# )
|
||||
|
||||
@@ -18,7 +18,6 @@ from compressed_tensors.quantization import (
|
||||
)
|
||||
from compressed_tensors.transform import TransformConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -52,7 +51,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16,
|
||||
CompressedTensorsW4A8Int8
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
|
||||
CompressedTensorsLinearTransformMethod,
|
||||
@@ -401,8 +399,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
or weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
)
|
||||
is_tensor = (
|
||||
weight_strategy
|
||||
@@ -420,8 +418,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
or weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
)
|
||||
is_token = (
|
||||
weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
|
||||
@@ -663,12 +661,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
)
|
||||
|
||||
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
if envs.VLLM_W8A8_LINEAR_USE_W4A8:
|
||||
return CompressedTensorsW4A8Int8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=False,
|
||||
input_symmetric=input_quant.symmetric,
|
||||
)
|
||||
return CompressedTensorsW8A8Int8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=False,
|
||||
|
||||
File diff suppressed because it is too large
@@ -8,7 +8,7 @@ from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||
from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8, CompressedTensorsW4A8Int8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
|
||||
|
||||
@@ -28,5 +28,4 @@ __all__ = [
|
||||
"CompressedTensorsW4A4Fp4",
|
||||
"CompressedTensorsW4A8Int",
|
||||
"CompressedTensorsW4A8Fp8",
|
||||
"CompressedTensorsW4A8Int8"
|
||||
]
|
||||
|
||||
@@ -25,11 +25,18 @@ logger = init_logger(__name__)
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool, is_w4a8_linear: bool = False
|
||||
):
|
||||
self.strategy = strategy
|
||||
import vllm.envs as env
|
||||
if env.VLLM_MIX_QUANTIZATION_TYPE == "TENSOR":
|
||||
self.strategy = QuantizationStrategy.TENSOR
|
||||
elif env.VLLM_MIX_QUANTIZATION_TYPE == "CHANNEL":
|
||||
self.strategy = QuantizationStrategy.CHANNEL
|
||||
else:
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
self.is_w4a8_linear = is_w4a8_linear
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@@ -53,16 +60,32 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
input_symmetric=self.input_symmetric,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
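# Pad the reduction dim (K) up to a multiple of 64; the int8/int4 GEMM kernels
# used here presumably require 64-aligned K.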
remainder = input_size_per_partition % 64
|
||||
if remainder != 0:
|
||||
input_size_per_partition_padded = input_size_per_partition + (64 - remainder)
|
||||
else:
|
||||
input_size_per_partition_padded = input_size_per_partition
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
if self.is_w4a8_linear:
|
||||
# only "NN" is supported
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
input_size_per_partition_padded,
|
||||
sum(output_partition_sizes) // 2,
|
||||
dtype=torch.int8),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition_padded,
|
||||
dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
@@ -109,104 +132,4 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Int8(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
self.kernel = init_int8_linear_kernel(
|
||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
input_symmetric=self.input_symmetric,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
# weight = ModelWeightParameter(
|
||||
# data=torch.empty(
|
||||
# sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
|
||||
# ),
|
||||
# input_dim=1,
|
||||
# output_dim=0,
|
||||
# weight_loader=weight_loader,
|
||||
# )
|
||||
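# W4A8 path: the weight is stored transposed ("NN" layout, K on dim 0) with two
# int4 values packed per int8 byte, hence the // 2 on the output dimension.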
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
sum(output_partition_sizes) // 2,
|
||||
dtype=torch.int8
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
input_zero_point = None
|
||||
input_scale = None
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
)
|
||||
if not self.input_symmetric:
|
||||
# Note: compressed-tensors stores the zp using the same dtype
|
||||
# as the weights
|
||||
# AZP loaded as int8 but used as int32
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
return self.kernel.apply_weights(layer, x, bias, self.is_w4a8_linear)
|
||||
|
||||
@@ -23,17 +23,13 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoeWeightScaleSupported,
|
||||
MoEActivation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_quant_config,
|
||||
@@ -50,9 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
@@ -860,14 +853,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
|
||||
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
|
||||
|
||||
# Setup modular kernel for TP case and naive DP/EP case.
|
||||
# In non-naive DP/EP case, we will create a ModularKernelMethod.
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_fp8_moe_kernel(
|
||||
self.moe_kernel = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
@@ -930,29 +919,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
# TRTLLM does not use Modular Kernel.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
|
||||
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
|
||||
a1_scale = layer.w13_input_scale
|
||||
@@ -983,10 +956,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def supports_eplb(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -994,50 +963,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
# TODO(rob): convert this to MK.
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Expected 'silu' activation but got {layer.activation}"
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
w13_weight_scale_inv=layer.w13_weight_scale_inv,
|
||||
w2_weight=layer.w2_weight,
|
||||
w2_weight_scale_inv=layer.w2_weight_scale_inv,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.weight_block_size,
|
||||
routing_method_type=layer.routing_method_type,
|
||||
routed_scaling=layer.routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
return apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -1046,9 +987,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.moe_mk is not None
|
||||
assert not self.is_monolithic
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from gguf import GGMLQuantizationType as WeightType
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
@@ -234,7 +235,7 @@ try:
|
||||
op_func=_fused_mul_mat_gguf,
|
||||
fake_impl=_fused_mul_mat_gguf_fake,
|
||||
)
|
||||
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
|
||||
fused_mul_mat_gguf = _fused_mul_mat_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
@@ -365,7 +366,7 @@ try:
|
||||
op_func=_fused_moe_gguf,
|
||||
fake_impl=_fused_moe_gguf_fake,
|
||||
)
|
||||
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
|
||||
fused_moe_gguf = _fused_moe_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
@@ -410,7 +411,7 @@ try:
|
||||
op_func=_apply_gguf_embedding,
|
||||
fake_impl=_apply_gguf_embedding_fake,
|
||||
)
|
||||
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
|
||||
apply_gguf_embedding = _apply_gguf_embedding
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
@@ -451,6 +452,9 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
"data_container": [],
|
||||
"shard_id": [],
|
||||
"shard_id_map": {},
|
||||
"params_dtype": params_dtype,
|
||||
"input_size_per_partition" :input_size_per_partition, # restore shape for qkv and merge
|
||||
"output_partition_sizes" :output_partition_sizes,
|
||||
},
|
||||
)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
@@ -664,6 +668,10 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
"""
|
||||
|
||||
def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
|
||||
weight = layer.weight
|
||||
return F.embedding(x, weight)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
hidden_size = qweight.tensor_shape[1]
|
||||
|
||||
@@ -128,7 +128,7 @@ class GPTQConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
|
||||
@@ -59,9 +59,164 @@ from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# [B, K//8, N] -> [B, K, N]
|
||||
# less memory
|
||||
def unpack_k_batch_opt(packed_w: torch.Tensor, num_bits: int = 4, chunk_size: int = 2) -> torch.Tensor:
|
||||
"""
|
||||
Memory-efficient unpacking for 3D tensor.
|
||||
Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor,
|
||||
without broadcasting huge intermediate tensors (avoids OOM).
|
||||
|
||||
Args:
|
||||
packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
|
||||
num_bits: Number of bits per packed element (e.g., 4 or 2).
|
||||
chunk_size: How many bit groups to unpack at once (tradeoff between speed and memory).
|
||||
|
||||
Returns:
|
||||
unpacked: torch.int8 tensor of shape [B, K, N].
|
||||
"""
|
||||
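    # Example (hypothetical shapes): with num_bits=4 the pack factor is 8, so a
    # packed tensor of shape (8, 256, 1024) unpacks to int8 of shape (8, 2048, 1024).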
B, k_packed, N = packed_w.shape
|
||||
pack_factor = 32 // num_bits
|
||||
K = k_packed * pack_factor
|
||||
mask = (1 << num_bits) - 1
|
||||
|
||||
# Allocate output tensor once
|
||||
unpacked = torch.empty((B, K, N), dtype=torch.int8, device=packed_w.device)
|
||||
|
||||
# Process bit chunks iteratively to save memory
|
||||
for i in range(0, pack_factor, chunk_size):
|
||||
# Precompute shifts for this chunk
|
||||
shift_vals = num_bits * torch.arange(i, min(i + chunk_size, pack_factor), device=packed_w.device)
|
||||
# [chunk_size, 1, 1, 1]
|
||||
shifts = shift_vals.view(-1, 1, 1, 1)
|
||||
# Compute small chunk only
|
||||
chunk = ((packed_w.unsqueeze(0) >> shifts) & mask).to(torch.int8)
|
||||
|
||||
# chunk: [chunk_size, B, k_packed, N]
|
||||
# write into output
|
||||
for j in range(chunk.shape[0]):
|
||||
unpacked[:, (i + j)::pack_factor, :] = chunk[j]
|
||||
|
||||
del chunk # release memory early
|
||||
|
||||
return unpacked
|
||||
|
||||
# more memory
|
||||
def unpack_k_batch(packed_w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
|
||||
"""
|
||||
Efficient vectorized unpacking for 3D tensor (batch version).
|
||||
Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor.
|
||||
|
||||
Args:
|
||||
packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
|
||||
num_bits: Number of bits per packed element (e.g., 4).
|
||||
|
||||
Returns:
|
||||
unpacked: torch.int8 tensor of shape [B, K, N].
|
||||
"""
|
||||
B, k_packed, n = packed_w.shape
|
||||
pack_factor = 32 // num_bits
|
||||
k = k_packed * pack_factor
|
||||
|
||||
mask = (1 << num_bits) - 1
|
||||
|
||||
# [pack_factor, 1, 1, 1]
|
||||
shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1, 1)
|
||||
|
||||
# [1, B, k_packed, N]
|
||||
packed_expanded = packed_w.unsqueeze(0)
|
||||
|
||||
# Extract each group of num_bits using bitwise ops
|
||||
unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8)
|
||||
|
||||
# [pack_factor, B, k_packed, N] → [B, K, N]
|
||||
unpacked = unpacked_groups.permute(1, 2, 0, 3).reshape(B, k, n)
|
||||
|
||||
return unpacked
|
||||
|
||||
|
||||
# [B, K, N] -> [B, K, N//8]
|
||||
# less memory
|
||||
def pack_n_batch_opt(x: torch.Tensor, pack_num: int = 8, order_map=None, chunk_size: int = 2) -> torch.Tensor:
|
||||
"""
|
||||
Memory-efficient batch packing with correct bit order.
|
||||
[B, K, N] int4 -> [B, K, N//pack_num] int32.
|
||||
"""
|
||||
B, K, N = x.shape
|
||||
assert N % pack_num == 0, "N must be divisible by pack_num"
|
||||
cols = N // pack_num
|
||||
unit = 32 // pack_num
|
||||
|
||||
if order_map is None:
|
||||
order_map = list(range(pack_num))
|
||||
order_map = torch.tensor(order_map, device=x.device)
|
||||
|
||||
shifts = unit * torch.arange(pack_num, device=x.device) # always 0..unit*(pack_num-1)
|
||||
packed = torch.zeros((B, K, cols), dtype=torch.int32, device=x.device)
|
||||
x_reshape = x.view(B, K, cols, pack_num) & 0xF
|
||||
|
||||
# process in chunks for memory efficiency
|
||||
for start in range(0, pack_num, chunk_size):
|
||||
end = min(start + chunk_size, pack_num)
|
||||
idx_chunk = order_map[start:end]
|
||||
shift_chunk = shifts[start:end]
|
||||
|
||||
vals = torch.gather(x_reshape, 3, idx_chunk.view(1, 1, 1, -1).expand(B, K, cols, -1)).to(torch.int32)
|
||||
for j in range(vals.shape[-1]):
|
||||
packed.add_(vals[..., j] << shift_chunk[j])
|
||||
|
||||
return packed
|
||||
|
||||
# more memory
|
||||
def pack_n_batch(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor:
|
||||
"""
|
||||
Efficient vectorized batch packing: [B, K, N] int4 -> [B, K, N//pack_num] int32.
|
||||
|
||||
Args:
|
||||
x: torch.int32 tensor of shape [B, K, N], each element 0-15 (int4).
|
||||
pack_num: Number of 4-bit elements per packed int32 (default=8).
|
||||
order_map: Optional order of elements within each packed int32.
|
||||
|
||||
Returns:
|
||||
torch.int32 tensor of shape [B, K, N//pack_num].
|
||||
"""
|
||||
|
||||
B, K, N = x.shape
|
||||
assert N % pack_num == 0, "N must be divisible by pack_num"
|
||||
cols = N // pack_num
|
||||
|
||||
if order_map is None:
|
||||
order_map = list(range(pack_num))
|
||||
order_map = torch.tensor(order_map, device=x.device)
|
||||
|
||||
unit = 32 // pack_num # number of bits per element
|
||||
|
||||
# reshape to [B, K, cols, pack_num]
|
||||
pack_num_int = int(pack_num)
|
||||
|
||||
x_reshape = x.view(B, K, cols, pack_num_int)
|
||||
|
||||
# reorder according to order_map
|
||||
x_reorder = torch.gather(
|
||||
x_reshape, 3, order_map.view(1, 1, 1, -1).expand(B, K, cols, -1)
|
||||
)
|
||||
|
||||
# mask low 4 bits
|
||||
x_reorder = x_reorder & 0xF
|
||||
|
||||
# bit shifts [pack_num] -> [1,1,1,pack_num] broadcastable
|
||||
shifts = (unit * torch.arange(pack_num_int, device=x.device)).view(1, 1, 1, -1)
|
||||
|
||||
# shift and sum along last dimension to combine bits
|
||||
packed = (x_reorder << shifts).sum(dim=-1).to(torch.int32)
|
||||
|
||||
return packed
|
||||
|
||||
|
||||
|
||||
def get_moe_quant_method(
|
||||
config: "GPTQMarlinConfig",
|
||||
@@ -495,8 +650,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.quant_type.size_bits == 4:
|
||||
self.quant_type = scalar_types.uint4b8
|
||||
elif self.quant_config.quant_type.size_bits == 8:
|
||||
self.quant_type = scalar_types.uint8b128
|
||||
# elif self.quant_config.quant_type.size_bits == 8:
|
||||
# self.quant_type = scalar_types.uint8b128
|
||||
else:
|
||||
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
|
||||
self.input_dtype = None
|
||||
@@ -594,7 +749,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=params_dtype,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -606,7 +761,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
num_experts,
|
||||
scales_size2,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
dtype=params_dtype,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -656,7 +811,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
# layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
|
||||
@@ -673,119 +828,111 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_scales.data = layer.w2_scales.data * 512
|
||||
|
||||
# Process act_order
|
||||
if self.quant_config.desc_act:
|
||||
# if self.quant_config.desc_act:
|
||||
# Get sorting based on g_idx
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
|
||||
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
|
||||
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
|
||||
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
|
||||
for e in range(num_experts):
|
||||
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
|
||||
torch.int32
|
||||
)
|
||||
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
|
||||
torch.int32
|
||||
)
|
||||
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
|
||||
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
|
||||
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
|
||||
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
|
||||
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
else:
|
||||
# Reset g_idx related tensors
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
device = layer.w13_g_idx.device
|
||||
layer.w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Repack weights
|
||||
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w13_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w2_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# num_experts = layer.w13_g_idx.shape[0]
|
||||
# w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
|
||||
# w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
|
||||
# w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
|
||||
# w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
|
||||
# for e in range(num_experts):
|
||||
# w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
|
||||
# torch.int32
|
||||
# )
|
||||
# w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
|
||||
# torch.int32
|
||||
# )
|
||||
# w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
|
||||
# w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
|
||||
# replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
|
||||
# replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
|
||||
# replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
# replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
# else:
|
||||
# # Reset g_idx related tensors
|
||||
# num_experts = layer.w13_g_idx.shape[0]
|
||||
# device = layer.w13_g_idx.device
|
||||
# layer.w13_g_idx = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# layer.w2_g_idx = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# # Repack weights
|
||||
# marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
# layer.w13_qweight,
|
||||
# layer.w13_g_idx_sort_indices,
|
||||
# layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
# layer.w13_qweight.shape[2],
|
||||
# self.quant_config.quant_type.size_bits,
|
||||
# )
|
||||
# replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
# marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
# layer.w2_qweight,
|
||||
# layer.w2_g_idx_sort_indices,
|
||||
# layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
# layer.w2_qweight.shape[2],
|
||||
# self.quant_config.quant_type.size_bits,
|
||||
# )
|
||||
# replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# # Repack scales
|
||||
# marlin_w13_scales = marlin_moe_permute_scales(
|
||||
# s=layer.w13_scales,
|
||||
# size_k=layer.intermediate_size_per_partition,
|
||||
# size_n=layer.w13_scales.shape[2],
|
||||
# group_size=self.quant_config.group_size,
|
||||
# )
|
||||
# replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
# marlin_w2_scales = marlin_moe_permute_scales(
|
||||
# s=layer.w2_scales,
|
||||
# size_k=layer.w2_scales.shape[1]
|
||||
# * (
|
||||
# self.quant_config.group_size
|
||||
# if self.quant_config.group_size != -1
|
||||
# else self.quant_config.pack_factor
|
||||
# ),
|
||||
# size_n=layer.w2_scales.shape[2],
|
||||
# group_size=self.quant_config.group_size,
|
||||
# )
|
||||
# replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
# The modular kernel expects w13_weight and w2_weight,
|
||||
# but GPTQ uses w13_qweight and w2_qweight
|
||||
# Alias for modular kernel
|
||||
layer.w13_weight = layer.w13_qweight
|
||||
# Alias for modular kernel
|
||||
layer.w2_weight = layer.w2_qweight
|
||||
# if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
# layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
|
||||
|
||||
# Repack scales
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
|
||||
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w13_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w13_input_global_scale",
|
||||
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
size_k=layer.w2_scales.shape[1]
|
||||
* (
|
||||
self.quant_config.group_size
|
||||
if self.quant_config.group_size != -1
|
||||
else self.quant_config.pack_factor
|
||||
),
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
|
||||
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w2_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w2_input_global_scale",
|
||||
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
|
||||
|
||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
# if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
# layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
if self.quant_config.desc_act:
|
||||
raise NotImplementedError(
|
||||
"GPTQMarlinMoEMethod now not support desc_act. please fix it")
|
||||
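# The GPTQ checkpoint packs eight int4 values along K into each int32;
# unpack_k_batch() expands them back out and pack_n_batch() re-packs eight
# values along N (order_map fixes the nibble order), which appears to be the
# layout ixformer's moe_w4a16_group_gemm expects.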
w13_qweight_unpacked = unpack_k_batch(layer.w13_qweight)
|
||||
w13_qweight_repacked = pack_n_batch(w13_qweight_unpacked, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
|
||||
replace_parameter(layer, "w13_qweight", w13_qweight_repacked)
|
||||
|
||||
# quant vllm/model_executor/layers/quantization/utils/quant_utils.py#quantize_weights
|
||||
# if quant_type.has_bias():
|
||||
# w_q += quant_type.bias
|
||||
# use quant_type.bias as the zero point (this is what ixformer's kernels expect)
|
||||
w13_zp = torch.full_like(layer.w13_scales, self.quant_type.bias, dtype=torch.int32)
|
||||
w13_zp_pack = pack_n_batch(w13_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
|
||||
replace_parameter(layer, "w13_qzeros", w13_zp_pack)
|
||||
|
||||
w2_qweight_unpacked = unpack_k_batch(layer.w2_qweight)
|
||||
w2_qweight_repacked = pack_n_batch(w2_qweight_unpacked, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
|
||||
replace_parameter(layer, "w2_qweight", w2_qweight_repacked)
|
||||
|
||||
w2_zp = torch.full_like(layer.w2_scales, self.quant_type.bias, dtype=torch.int32)
|
||||
w2_zp_pack = pack_n_batch(w2_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
|
||||
replace_parameter(layer, "w2_qzeros", w2_zp_pack)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
@@ -900,30 +1047,165 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# shared_experts_output is passed in as shared_experts_input so it can be fused into the final reduce
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
getattr(layer, "w13_bias", None),
|
||||
getattr(layer, "w2_bias", None),
|
||||
layer.w13_scales,
|
||||
layer.w2_scales,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full,
|
||||
input_dtype=self.input_dtype,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
assert layer.activation.value == "silu", "Only SiLU activation is supported."
|
||||
use_ep = layer.expert_map is not None
|
||||
|
||||
if use_ep:
|
||||
start_eid = layer.ep_rank * layer.local_num_experts
|
||||
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
|
||||
|
||||
if layer.apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"GPTQMarlinMoEMethod Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
if (hasattr(layer, "w13_bias") and layer.w13_bias is not None) or (hasattr(layer, "w2_bias") and layer.w2_bias is not None):
|
||||
raise NotImplementedError(
|
||||
"GPTQMarlinMoEMethod moe_w4a16_group_gemm not supported bias, please fix this")
|
||||
|
||||
num_tokens = topk_ids.shape[0]
|
||||
num_experts = layer.global_num_experts
|
||||
|
||||
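# With expert parallelism, only tokens routed to this rank's expert range
# [start_eid, end_eid) are expanded; entries mapped to -1 in src_to_dst are
# masked out of the final reduce.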
if use_ep:
|
||||
hidden_size = x.shape[1]
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
expert_sizes_gpu,
|
||||
expert_sizes_cpu,
|
||||
expand_tokens,
|
||||
) = ixfops.moe_compute_token_index_ep(
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
start_expert_id=start_eid,
|
||||
end_expert_id=end_eid,
|
||||
)
|
||||
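# No tokens were routed to this rank's local experts; skip the group GEMMs and
# return a zero tensor of the expected output shape.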
if expert_sizes_cpu.sum() == 0:
|
||||
return torch.zeros(
|
||||
(num_tokens, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
else:
|
||||
expand_tokens = num_tokens * layer.top_k
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
expert_sizes_gpu,
|
||||
expert_sizes_cpu,
|
||||
) = ixfops.moe_compute_token_index(
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
)
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
|
||||
# expand + reorder
|
||||
# TODO use kernel
|
||||
expand_hidden_states = ixfops.moe_expand_input(
|
||||
hidden_states=x,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_tokens=expand_tokens,
|
||||
topk=layer.top_k,
|
||||
src_to_dst=src_to_dst,
|
||||
)
|
||||
|
||||
# w4a16 group gemm 1
|
||||
# pt_output_1: shape (expand_tokens, 2n), in the activation dtype
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemm(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
# w4a16 group gemm 2 + reorder
|
||||
# pt_output_3: shape (expand_tokens, k), in the activation dtype
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * layer.top_k, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
output=pt_output_3,
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
reduce_mask = src_to_dst == -1
|
||||
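# Slots with src_to_dst == -1 were routed to experts owned by other EP ranks and
# hold no local result, so mask them out of the weighted reduction (assumed from
# the -1 sentinel produced by moe_compute_token_index_ep).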
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
extra_residual=shared_experts_input,
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
@@ -12,8 +12,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
NvFp4MoeBackend,
|
||||
convert_to_nvfp4_moe_kernel_format,
|
||||
is_global_sf_supported_for_nvfp4_backend,
|
||||
make_nvfp4_moe_kernel,
|
||||
@@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
flashinfer_trtllm_fp4_moe,
|
||||
flashinfer_trtllm_fp4_routed_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
@@ -114,6 +104,8 @@ QUANT_ALGOS = [
|
||||
"NVFP4",
|
||||
# MXFP8
|
||||
"MXFP8",
|
||||
# MIXED_PRECISION,
|
||||
"MIXED_PRECISION",
|
||||
]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
||||
|
||||
@@ -181,7 +173,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
# handle kv-cache first so we can focus only on weight quantization thereafter
|
||||
if isinstance(layer, Attention):
|
||||
if isinstance(layer, (Attention, MLAAttention)):
|
||||
return self.KVCacheMethodCls(self)
|
||||
|
||||
# handle exclusion
|
||||
@@ -235,6 +227,26 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
|
||||
self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)
|
||||
|
||||
@staticmethod
|
||||
def _extract_modelopt_quant_algo(
|
||||
hf_quant_cfg: dict[str, Any] | None,
|
||||
) -> str | None:
|
||||
"""Extract upper-cased quant_algo from a modelopt config.
|
||||
|
||||
Returns the quant_algo string (upper-cased), or None if the config
|
||||
is not a modelopt config.
|
||||
"""
|
||||
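# Illustrative examples (hypothetical configs): both of the following return "FP8":
#   {"quant_method": "modelopt", "quantization": {"quant_algo": "fp8"}}
#   {"quant_method": "modelopt", "quant_algo": "fp8"}
# A config whose quant_method is not "modelopt" returns None.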
if hf_quant_cfg is None:
|
||||
return None
|
||||
if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
|
||||
return None
|
||||
if "quantization" in hf_quant_cfg:
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
return str(quant_config.get("quant_algo", "")).upper()
|
||||
return None
|
||||
return str(hf_quant_cfg.get("quant_algo", "")).upper()
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
@@ -272,10 +284,20 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
# "exclude_modules" is the key in the legacy hf_quant_config.json
|
||||
exclude_modules = quant_config.get("exclude_modules", [])
|
||||
else:
|
||||
# Compressed-tensors style format:
|
||||
# Compressed-tensors style format (config.json quantization_config):
|
||||
# {"quant_algo": "...", "quant_method": "modelopt"}
|
||||
quant_method = config.get("quant_algo")
|
||||
kv_cache_quant_method = config.get("kv_cache_quant_algo")
|
||||
|
||||
# "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string).
|
||||
kv_cache_scheme = config.get("kv_cache_scheme")
|
||||
if isinstance(kv_cache_scheme, dict) and (
|
||||
kv_cache_scheme.get("type") == "float"
|
||||
and kv_cache_scheme.get("num_bits") == 8
|
||||
):
|
||||
kv_cache_quant_method = "FP8"
|
||||
else:
|
||||
kv_cache_quant_method = None
|
||||
|
||||
# "ignore" is the key in config.json
|
||||
exclude_modules = config.get("ignore", [])
|
||||
group_size_raw = config.get("group_size")
|
||||
@@ -379,32 +401,9 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
"""Detect if this ModelOpt config should be used based on
|
||||
quantization config."""
|
||||
|
||||
if hf_quant_cfg is None:
|
||||
return None
|
||||
|
||||
# Use the community standard 'quant_method'
|
||||
quant_method = hf_quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
# Only proceed if the method is explicitly "modelopt"
|
||||
if quant_method != "modelopt":
|
||||
return None
|
||||
|
||||
# Look for ModelOpt-specific config structure
|
||||
if "quantization" in hf_quant_cfg:
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
quant_algo = str(quant_config.get("quant_algo", ""))
|
||||
if quant_algo.upper() == "FP8":
|
||||
return "modelopt"
|
||||
else:
|
||||
# Check for compressed-tensors style config with specific quant_algo
|
||||
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
|
||||
if quant_algo.upper() == "FP8":
|
||||
return "modelopt"
|
||||
|
||||
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
|
||||
if algo is not None and algo == "FP8":
|
||||
return "modelopt"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -737,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
@@ -745,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
@@ -862,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# Setup modular kernel.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
assert self.experts_cls is not None
|
||||
self.moe_kernel = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
w13 = layer.w13_weight
|
||||
@@ -904,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
a1_scale = layer.w13_input_scale
|
||||
@@ -920,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -931,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
|
||||
)
|
||||
# TODO(rob): this validation should happen at kernel selection
|
||||
# time in the oracle rather than here.
|
||||
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert layer.activation in SUPPORTED_ACTIVATIONS, (
|
||||
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
|
||||
f"TRTLLM FP4 MoE, {layer.activation} found instead."
|
||||
)
|
||||
return apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def apply(
|
||||
@@ -964,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
|
||||
# TODO(rob): this validation should happen at kernel selection
|
||||
# time in the oracle rather than here.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
assert layer.activation in (
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
), (
|
||||
"Expected activation to be in ('silu', 'relu2_no_mul'),"
|
||||
f"but got {layer.activation}"
|
||||
)
|
||||
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
@@ -1031,32 +1003,9 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
"""Detect if this ModelOpt FP4 config should be used based on
|
||||
quantization config."""
|
||||
if hf_quant_cfg is None:
|
||||
return None
|
||||
|
||||
# Use the community standard 'quant_method'
|
||||
quant_method = hf_quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
# Only proceed if the method is explicitly "modelopt"
|
||||
if quant_method != "modelopt":
|
||||
return None
|
||||
|
||||
# Look for ModelOpt-specific config structure
|
||||
if "quantization" in hf_quant_cfg:
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
quant_algo = quant_config.get("quant_algo", "")
|
||||
if "NVFP4" in quant_algo:
|
||||
return "modelopt_fp4"
|
||||
else:
|
||||
# Check for compressed-tensors style config with specific
|
||||
# quant_algo field
|
||||
quant_algo = hf_quant_cfg.get("quant_algo", "")
|
||||
if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
|
||||
return "modelopt_fp4"
|
||||
|
||||
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
|
||||
if algo is not None and ("NVFP4" in algo or "FP4" in algo):
|
||||
return "modelopt_fp4"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -1249,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
@@ -1434,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
|
||||
replace_parameter(layer, "w2_input_scale", a2_scale)
|
||||
|
||||
# Setup modular kernel for TP case and naive DP/EP case.
|
||||
# In non-naive DP/EP case, we will create a ModularKernelMethod.
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
# Setup modular kernel.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
|
||||
@property
|
||||
def do_post_quant_allgather(self):
|
||||
return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def prepare_dp_allgather_tensor(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
|
||||
if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
raise RuntimeError(
|
||||
"prepare_dp_allgather_tensor is only supported for "
|
||||
"FlashInfer TRTLLM NVFP4 MoE backend."
|
||||
)
|
||||
|
||||
import flashinfer
|
||||
|
||||
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
|
||||
hidden_states,
|
||||
layer.a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
assert self.experts_cls is not None
|
||||
self.moe_kernel = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
extra_tensors: list[torch.Tensor] = [hidden_states_sf]
|
||||
return hidden_states_fp4, extra_tensors
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
return make_nvfp4_moe_quant_config(
|
||||
backend=self.nvfp4_backend,
|
||||
w13_scale=layer.w13_weight_scale,
|
||||
@@ -1493,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
def supports_eplb(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return (
|
||||
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
and not self.moe.moe_parallel_config.enable_eplb
|
||||
)
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -1507,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert (
|
||||
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
and not layer.enable_eplb
|
||||
)
|
||||
|
||||
return flashinfer_trtllm_fp4_moe(
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=layer.top_k,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
custom_routing_function=layer.custom_routing_function,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def apply(
|
||||
@@ -1534,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
|
||||
# EPLB path
|
||||
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
assert layer.enable_eplb
|
||||
return flashinfer_trtllm_fp4_routed_moe(
|
||||
layer=layer,
|
||||
x=x,
|
||||
topk_ids=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
top_k=layer.top_k,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
)
|
||||
else:
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
|
||||
@@ -1619,31 +1502,9 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
"""Detect if this ModelOpt MXFP8 config should be used based on
|
||||
quantization config."""
|
||||
if hf_quant_cfg is None:
|
||||
return None
|
||||
|
||||
# Use the community standard 'quant_method'
|
||||
quant_method = hf_quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
# Only proceed if the method is explicitly "modelopt"
|
||||
if quant_method != "modelopt":
|
||||
return None
|
||||
|
||||
# Look for ModelOpt-specific config structure
|
||||
if "quantization" in hf_quant_cfg:
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
quant_algo = str(quant_config.get("quant_algo", "")).upper()
|
||||
if "MXFP8" in quant_algo:
|
||||
return "modelopt_mxfp8"
|
||||
else:
|
||||
# Check for compressed-tensors style config with specific quant_algo
|
||||
quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper()
|
||||
if "MXFP8" in quant_algo:
|
||||
return "modelopt_mxfp8"
|
||||
|
||||
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
|
||||
if algo is not None and "MXFP8" in algo:
|
||||
return "modelopt_mxfp8"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -1841,3 +1702,188 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
||||
# Register the method classes for ModelOptMxFp8Config
|
||||
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
|
||||
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
|
||||
|
||||
|
||||
class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
|
||||
"""Config class for ModelOpt MIXED_PRECISION.
|
||||
|
||||
Supports checkpoints where different layers use different quantization
|
||||
algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts).
|
||||
The per-layer algorithm is specified in the ``quantized_layers`` dict
|
||||
inside ``config.json``'s ``quantization_config`` (preferred) or the
|
||||
legacy ``hf_quant_config.json``.
|
||||
"""
|
||||
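# Illustrative quantization_config for MIXED_PRECISION (hypothetical layer names):
#   {"quant_method": "modelopt", "quant_algo": "MIXED_PRECISION",
#    "quantized_layers": {
#        "model.layers.0.self_attn.q_proj": {"quant_algo": "FP8"},
#        "model.layers.0.mlp.experts": {"quant_algo": "NVFP4", "group_size": 16},
#    }}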
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_quant_method: str | None,
|
||||
exclude_modules: list[str],
|
||||
quantized_layers: dict[str, dict[str, Any]],
|
||||
fp8_config: ModelOptFp8Config,
|
||||
nvfp4_config: ModelOptNvFp4Config,
|
||||
) -> None:
|
||||
super().__init__(exclude_modules)
|
||||
self.kv_cache_quant_method = kv_cache_quant_method
|
||||
self.quantized_layers = quantized_layers
|
||||
self.fp8_config = fp8_config
|
||||
self.nvfp4_config = nvfp4_config
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "modelopt_mixed"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 89
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
|
||||
if algo is not None and algo == "MIXED_PRECISION":
|
||||
return "modelopt_mixed"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _from_config(
|
||||
cls,
|
||||
*,
|
||||
quant_method: str,
|
||||
kv_cache_quant_method: str | None,
|
||||
exclude_modules: list[str],
|
||||
original_config: dict[str, Any],
|
||||
group_size: int | None,
|
||||
**kwargs: Any,
|
||||
) -> "ModelOptMixedPrecisionConfig":
|
||||
if "quantization" in original_config:
|
||||
quantized_layers = original_config["quantization"].get(
|
||||
"quantized_layers", {}
|
||||
)
|
||||
else:
|
||||
quantized_layers = original_config.get("quantized_layers", {})
|
||||
|
||||
if not quantized_layers:
|
||||
raise ValueError(
|
||||
"MIXED_PRECISION quant_algo requires a non-empty "
|
||||
"'quantized_layers' mapping in the quantization config."
|
||||
)
|
||||
|
||||
# Determine group_size from the first NVFP4 entry if not provided.
|
||||
if group_size is None:
|
||||
for layer_info in quantized_layers.values():
|
||||
if layer_info.get("quant_algo", "").upper() == "NVFP4":
|
||||
group_size = layer_info.get("group_size", 16)
|
||||
break
|
||||
if group_size is None:
|
||||
group_size = 16
|
||||
|
||||
fp8_config = ModelOptFp8Config(
|
||||
quant_method="FP8",
|
||||
is_checkpoint_fp8_serialized=True,
|
||||
kv_cache_quant_method=kv_cache_quant_method,
|
||||
exclude_modules=[],
|
||||
)
|
||||
nvfp4_config = ModelOptNvFp4Config(
|
||||
is_checkpoint_nvfp4_serialized=True,
|
||||
kv_cache_quant_algo=kv_cache_quant_method,
|
||||
exclude_modules=[],
|
||||
group_size=group_size,
|
||||
)
|
||||
|
||||
return cls(
|
||||
kv_cache_quant_method=kv_cache_quant_method,
|
||||
exclude_modules=exclude_modules,
|
||||
quantized_layers=quantized_layers,
|
||||
fp8_config=fp8_config,
|
||||
nvfp4_config=nvfp4_config,
|
||||
)
|
||||
|
||||
def _resolve_quant_algo(self, prefix: str) -> str | None:
|
||||
"""Look up the quant_algo for a vLLM-side layer prefix.
|
||||
|
||||
Tries three strategies in order:
|
||||
1. Direct lookup in ``quantized_layers``.
|
||||
2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``).
|
||||
3. Prefix-based lookup for FusedMoE (any child key starts with
|
||||
``prefix + "."``).
|
||||
|
||||
Returns the upper-cased quant_algo string, or *None* if the prefix
|
||||
is not found.
|
||||
"""
|
||||
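# Illustrative example (hypothetical names): with
#   quantized_layers = {"model.layers.0.self_attn.q_proj": {"quant_algo": "FP8"},
#                       "model.layers.0.mlp.experts.0.gate_proj": {"quant_algo": "NVFP4"}}
# the fused prefix "model.layers.0.self_attn.qkv_proj" resolves to "FP8" via the
# packed-module lookup, and "model.layers.0.mlp.experts" to "NVFP4" via the
# prefix lookup.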
# 1. Direct lookup
|
||||
if prefix in self.quantized_layers:
|
||||
return self.quantized_layers[prefix]["quant_algo"].upper()
|
||||
|
||||
# 2. Packed / fused layer lookup
|
||||
proj_name = prefix.rsplit(".", 1)[-1]
|
||||
if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
|
||||
algos: set[str] = set()
|
||||
base = prefix.rsplit(".", 1)[0]
|
||||
for shard_name in self.packed_modules_mapping[proj_name]:
|
||||
shard_prefix = f"{base}.{shard_name}"
|
||||
if shard_prefix in self.quantized_layers:
|
||||
algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
|
||||
if len(algos) == 1:
|
||||
return algos.pop()
|
||||
if len(algos) > 1:
|
||||
raise ValueError(
|
||||
f"Mixed quant_algo within fused layer {prefix}: "
|
||||
f"{algos}. All shards must use the same quantization."
|
||||
)
|
||||
|
||||
# 3. Prefix-based lookup (for FusedMoE / parent modules)
|
||||
prefix_dot = prefix + "."
|
||||
for key, info in self.quantized_layers.items():
|
||||
if key.startswith(prefix_dot):
|
||||
return info["quant_algo"].upper()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
"""Return quantize-method based on layer."""
|
||||
# KV-cache quantization
|
||||
if isinstance(layer, Attention):
|
||||
if self.kv_cache_quant_method:
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
# Excluded layers
|
||||
if self.is_layer_excluded(prefix):
|
||||
if isinstance(layer, LinearBase):
|
||||
return UnquantizedLinearMethod()
|
||||
return None
|
||||
|
||||
quant_algo = self._resolve_quant_algo(prefix)
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if quant_algo == "FP8":
|
||||
return ModelOptFp8LinearMethod(self.fp8_config)
|
||||
if quant_algo == "NVFP4":
|
||||
return ModelOptNvFp4LinearMethod(self.nvfp4_config)
|
||||
# Layer not in quantized_layers — leave unquantized
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if quant_algo == "FP8":
|
||||
return ModelOptFp8MoEMethod(
|
||||
quant_config=self.fp8_config,
|
||||
moe_config=layer.moe_config,
|
||||
)
|
||||
if quant_algo == "NVFP4":
|
||||
return ModelOptNvFp4FusedMoE(
|
||||
quant_config=self.nvfp4_config,
|
||||
moe_config=layer.moe_config,
|
||||
)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
super().apply_vllm_mapper(hf_to_vllm_mapper)
|
||||
if self.quantized_layers:
|
||||
self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
@@ -77,6 +78,8 @@ class Mxfp4Backend(Enum):
|
||||
# Triton Backend
|
||||
TRITON = 6
|
||||
|
||||
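# ROCm AIter CK (Composable Kernel) MXFP4 Backend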
CK = 7
|
||||
|
||||
|
||||
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
|
||||
"""
|
||||
@@ -167,9 +170,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
|
||||
elif current_platform.is_xpu():
|
||||
logger.info_once("Using xpu backend on XPU")
|
||||
return Mxfp4Backend.MARLIN
|
||||
elif current_platform.is_rocm() and has_triton_kernels():
|
||||
logger.info_once("Using Triton backend")
|
||||
return Mxfp4Backend.TRITON
|
||||
elif current_platform.is_rocm():
|
||||
from vllm.platforms.rocm import on_gfx950
|
||||
|
||||
if rocm_aiter_ops.is_enabled() and on_gfx950():
|
||||
logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
|
||||
return Mxfp4Backend.CK
|
||||
elif has_triton_kernels():
|
||||
logger.info_once("Using Triton backend")
|
||||
return Mxfp4Backend.TRITON
|
||||
|
||||
return Mxfp4Backend.NONE
|
||||
|
||||
@@ -257,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
|
||||
self.moe_mk: mk.FusedMoEModularKernel | None = None
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -338,6 +347,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
|
||||
self.intermediate_pad = (
|
||||
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
|
||||
)
|
||||
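# Record how much padding was added beyond the logical sizes; the CK fused-MoE
# path passes hidden_pad / intermediate_pad to rocm_aiter_ops.fused_moe in
# apply() (assumed to let the kernel account for the padded columns).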
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
@@ -427,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
self.moe_mk = mk.FusedMoEModularKernel(
|
||||
self.moe_kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
MarlinExperts(
|
||||
self.moe,
|
||||
@@ -776,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
self.moe_mk = mk.FusedMoEModularKernel(
|
||||
self.moe_kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
FlashInferExperts(
|
||||
moe_config=self.moe,
|
||||
@@ -784,6 +797,66 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
),
|
||||
shared_experts=None,
|
||||
)
|
||||
elif self.mxfp4_backend == Mxfp4Backend.CK:
|
||||
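# CK (ROCm AIter) backend: cast biases to fp32, interleave the gate/up halves of
# w13 along N, then apply AIter's a16w4 shuffles to weights and scales. The
# required layouts are assumed from the shuffle_weight_a16w4 /
# shuffle_scale_a16w4 calls below.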
if layer.w13_bias is not None:
|
||||
layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
|
||||
if layer.w2_bias is not None:
|
||||
layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)
|
||||
|
||||
e, n, k = layer.w13_weight.shape
|
||||
layer.w13_weight.view(torch.uint8).copy_(
|
||||
layer.w13_weight.data.view(torch.uint8)
|
||||
.view(e, n // 2, 2, k)
|
||||
.permute(0, 2, 1, 3)
|
||||
.contiguous()
|
||||
.view(e, n, k)
|
||||
)
|
||||
layer.w13_weight_scale.data = (
|
||||
layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
|
||||
.permute(0, 2, 1, 3)
|
||||
.contiguous()
|
||||
.view(e, n, -1)
|
||||
)
|
||||
layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
|
||||
layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)
|
||||
|
||||
layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
|
||||
layer.w13_weight, 16, True
|
||||
)
|
||||
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
|
||||
layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
|
||||
self.num_experts,
|
||||
True,
|
||||
)
|
||||
|
||||
layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
|
||||
layer.w2_weight, 16, False
|
||||
)
|
||||
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
|
||||
layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
|
||||
self.num_experts,
|
||||
False,
|
||||
)
|
||||
|
||||
layer.w13_bias.data = (
|
||||
layer.w13_bias.data.view(-1, n // 2, 2)
|
||||
.permute(0, 2, 1)
|
||||
.contiguous()
|
||||
.view(-1, n)
|
||||
)
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
shuffled_w13_scale, requires_grad=False
|
||||
)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
shuffled_w2_scale, requires_grad=False
|
||||
)
|
||||
# replace_parameter(layer, "w13_bias", w13_bias)
|
||||
# replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
|
||||
# replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
|
||||
# replace_parameter(layer, "w13_weight", w13_weight)
|
||||
# replace_parameter(layer, "w2_weight", w2_weight)
|
||||
|
||||
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
@@ -792,18 +865,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
|
||||
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
|
||||
|
||||
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
|
||||
# (stored in self.fused_experts) to determine if the MoE has a
|
||||
# batched activation format. As self.fused_experts is not
|
||||
# initialized at this point, we resort to checking the MoE config
|
||||
# directly.
|
||||
is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
|
||||
is_batched_moe = self.moe.use_deepep_ll_kernels
|
||||
if is_batched_moe:
|
||||
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
|
||||
else:
|
||||
num_warps = 8
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps
|
||||
)
|
||||
@@ -817,13 +888,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
|
||||
)
|
||||
|
||||
self.w13_weight = w13_weight
|
||||
self.w2_weight = w2_weight
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
layer.w13_weight = w13_weight
|
||||
layer.w2_weight = w2_weight
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
|
||||
@@ -862,6 +933,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
elif self.mxfp4_backend in [
|
||||
Mxfp4Backend.SM100_FI_MXFP4_BF16,
|
||||
Mxfp4Backend.SM90_FI_MXFP4_BF16,
|
||||
Mxfp4Backend.CK,
|
||||
]:
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
@@ -882,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
== mk.FusedMoEActivationFormat.BatchedExperts
|
||||
@@ -929,10 +1001,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
if self.moe.is_lora_enabled:
|
||||
return False
|
||||
return (
|
||||
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
|
||||
or self.mxfp4_backend == Mxfp4Backend.TRITON
|
||||
or self.mxfp4_backend == Mxfp4Backend.CK
|
||||
)
|
||||
|
||||
def apply(
|
||||
@@ -968,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
or self.mxfp4_backend == Mxfp4Backend.MARLIN
|
||||
)
|
||||
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@@ -1054,6 +1129,27 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
elif self.mxfp4_backend == Mxfp4Backend.CK:
|
||||
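# CK backend: run AIter's fused top-k routing, then its fused MoE with per-1x32
# (MXFP4) weight scales; the pad arguments are aligned down to the kernel's
# expected multiples (values taken from the expressions below).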
topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
|
||||
x, router_logits, layer.top_k, True
|
||||
)
|
||||
output = rocm_aiter_ops.fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
|
||||
quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
doweight_stage1=False,
|
||||
hidden_pad=self.hidden_pad // 128 * 128,
|
||||
intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
|
||||
bias1=layer.w13_bias,
|
||||
bias2=layer.w2_bias,
|
||||
)
|
||||
return output
|
||||
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
|
||||
triton_kernel_moe_forward,
|
||||
@@ -1162,7 +1258,7 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
|
||||
topk_weights=routing_weights,
|
||||
topk_ids=selected_experts,
|
||||
n_experts_per_token=layer.top_k,
|
||||
activation=layer.activation,
|
||||
activation=layer.activation.value,
|
||||
num_experts=layer.local_num_experts,
|
||||
is_mxfp4=True,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
@@ -26,10 +25,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PTPCFp8Config(Fp8Config):
|
||||
"""Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""
|
||||
|
||||
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
|
||||
)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
@@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig):
|
||||
self.kv_cache_group = kv_cache_group
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.pack_method = pack_method
|
||||
self.dynamic_mxfp4_quant = False
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
self.hf_config = get_config(
|
||||
model=model_name,
|
||||
trust_remote_code=False, # or get from model_config if available
|
||||
revision=revision,
|
||||
config_format="auto",
|
||||
)
|
||||
|
||||
quant_config = getattr(self.hf_config, "quantization_config", None)
|
||||
if quant_config is not None:
|
||||
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
|
||||
model_type = self.hf_config.model_type
|
||||
if quant_dtype == "fp4" and model_type == "deepseek_v3":
|
||||
self.dynamic_mxfp4_quant = True
|
||||
|
||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||
return QuarkLinearMethod(self)
|
||||
@@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig):
|
||||
if should_ignore_layer(
|
||||
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
if (
|
||||
"self_attn" not in prefix # only quantize attention projections
|
||||
or not getattr(self, "dynamic_mxfp4_quant", False)
|
||||
or not isinstance(layer, LinearBase) # Ignore other methods
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
scheme = self.get_scheme(
|
||||
layer=layer,
|
||||
layer_name=prefix,
|
||||
dynamic_mxfp4_quant=True,
|
||||
)
|
||||
layer.scheme = scheme
|
||||
return QuarkLinearMethod(self)
|
||||
if isinstance(layer, LinearBase):
|
||||
scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
layer.scheme = scheme
|
||||
@@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig):
|
||||
)
|
||||
return global_quant_config
|
||||
|
||||
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
|
||||
def _get_scheme_from_config(
|
||||
self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False
|
||||
) -> "QuarkScheme":
|
||||
if config.get("output_tensors") or config.get("bias"):
|
||||
raise NotImplementedError(
|
||||
"Currently, Quark models with output_tensors "
|
||||
@@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig):
|
||||
input_symmetric=input_config.get("symmetric"),
|
||||
)
|
||||
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
return QuarkOCP_MX(weight_config, input_config)
|
||||
return QuarkOCP_MX(
|
||||
weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"No quark compatible scheme was found. "
|
||||
@@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig):
|
||||
f"Input config: {input_config}"
|
||||
)
|
||||
|
||||
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
|
||||
def get_scheme(
|
||||
self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False
|
||||
) -> "QuarkScheme":
|
||||
layer_quant_config = self._find_matched_config(layer_name, layer)
|
||||
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_config(layer_quant_config)
|
||||
scheme = self._get_scheme_from_config(
|
||||
layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
|
||||
)
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
self._check_scheme_supported(scheme.get_min_capability())
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_fp8_moe_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
@@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
|
||||
__all__ = [
|
||||
"QuarkMoEMethod",
|
||||
"QuarkOCP_MX_MoEMethod",
|
||||
"QuarkOCP_MX_MoEMethod_OSS",
|
||||
]
|
||||
|
||||
|
||||
class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
@@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
"output_tensors and bias "
|
||||
"quantized are not supported"
|
||||
)
|
||||
|
||||
weight_config = layer_quant_config.get("weight")
|
||||
input_config = layer_quant_config.get("input_tensors")
|
||||
|
||||
if quant_config._is_fp8_w4a8(weight_config, input_config):
|
||||
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
|
||||
emulate = not current_platform.supports_mx() or not (
|
||||
rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
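# Use the GPT-OSS Triton path only when MX formats run natively (no emulation)
# and the activations use static fp8_e4m3 scales; otherwise fall back to the
# generic OCP-MX MoE method.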
if (
|
||||
input_config.get("dtype") == "fp8_e4m3"
|
||||
and not input_config.get("is_dynamic")
|
||||
and not emulate
|
||||
):
|
||||
return QuarkOCP_MX_MoEMethod_OSS(
|
||||
weight_config, input_config, module.moe_config
|
||||
)
|
||||
else:
|
||||
return QuarkOCP_MX_MoEMethod(
|
||||
weight_config, input_config, module.moe_config
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Unsupported FusedMoe scheme")
|
||||
|
||||
@@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
get_current_vllm_config().model_config.hf_config, "model_type", None
|
||||
)
|
||||
|
||||
self._emulate = (
|
||||
self.emulate = (
|
||||
not current_platform.supports_mx()
|
||||
or not self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
|
||||
|
||||
self.emulate = True if self.model_type == "gpt_oss" else self._emulate
|
||||
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
@@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
)
|
||||
|
||||
params_dtype = torch.uint8
|
||||
self.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
if self.model_type == "gpt_oss":
|
||||
if current_platform.is_rocm():
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
@@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
|
||||
|
||||
self.unpadded_hidden_size = extra_weight_attrs.get(
|
||||
"unpadded_hidden_size", hidden_size
|
||||
)
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
@@ -991,30 +1015,20 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if not self.emulate:
|
||||
if (
|
||||
self.model_type == "gpt_oss"
|
||||
and self.mxfp4_backend == Mxfp4Backend.TRITON
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Triton kernel implemented fused MoE for GPT_OSS model "
|
||||
"in Quark(MoE) format is not integrated or provided yet."
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@@ -1031,3 +1045,133 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
|
||||
def __init__(
|
||||
self,
|
||||
weight_config: dict[str, Any],
|
||||
input_config: dict[str, Any],
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(weight_config, input_config, moe)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
w13_bias = layer.w13_bias.to(torch.float32)
|
||||
w2_bias = layer.w2_bias.to(torch.float32)
|
||||
|
||||
layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
|
||||
layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
|
||||
|
||||
# FIXME: num_warps needs to be adjusted based on the batch size
|
||||
# only apply to batched mode
|
||||
if self.moe.use_ep:
|
||||
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
|
||||
else:
|
||||
num_warps = 8
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps
|
||||
)
|
||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||
layer.w2_weight, layer.w2_weight_scale, num_warps
|
||||
)
|
||||
|
||||
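# _swizzle_mxfp4 converts the packed MXFP4 weights and their scales into the
# layout expected by triton_kernels' matmul_ogs; only the swizzled copies are
# kept from here on (the originals are freed below).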
self.w13_weight_triton_tensor = w13_weight
|
||||
self.w2_weight_triton_tensor = w2_weight
|
||||
|
||||
# need to delete the original weights to save memory on single GPU
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
layer.w13_weight = None
|
||||
layer.w2_weight = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.static_input_scales:
|
||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None."
|
||||
)
|
||||
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
|
||||
layer.w2_input_scale
|
||||
):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer."
|
||||
)
|
||||
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max().to(torch.float32), requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max().to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
from triton_kernels.numerics import InFlexData
|
||||
|
||||
lhs_data13 = InFlexData(scale=layer.w13_input_scale)
|
||||
lhs_data2 = InFlexData(scale=layer.w2_input_scale)
|
||||
|
||||
self.w13_precision_config = PrecisionConfig(
|
||||
weight_scale=w13_scale,
|
||||
flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13),
|
||||
)
|
||||
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale,
|
||||
flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2),
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return mxfp4_w4a8_moe_quant_config(
|
||||
w1_scale=self.w13_precision_config,
|
||||
w2_scale=self.w2_precision_config,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return True
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet."
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
|
||||
return triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=self.w13_weight_triton_tensor,
|
||||
w2=self.w2_weight_triton_tensor,
|
||||
gating_output=router_logits,
|
||||
topk=layer.top_k,
|
||||
renormalize=layer.renormalize,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
unpadded_N_w1=self.intermediate_size_per_partition * 2,
|
||||
unpadded_K_w1=self.unpadded_hidden_size,
|
||||
unpadded_N_w2=self.unpadded_hidden_size,
|
||||
unpadded_K_w2=self.intermediate_size_per_partition,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
)
|
||||
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .quark_scheme import QuarkScheme
|
||||
@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError):
|
||||
|
||||
class QuarkOCP_MX(QuarkScheme):
|
||||
def __init__(
|
||||
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
|
||||
self,
|
||||
weight_quant_spec: dict[str, Any],
|
||||
input_quant_spec: dict[str, Any],
|
||||
dynamic_mxfp4_quant: bool = False,
|
||||
):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.qscheme = "per_group"
|
||||
self.weight_quant_spec = weight_quant_spec
|
||||
self.input_quant_spec = input_quant_spec
|
||||
|
||||
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
|
||||
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
|
||||
@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme):
                layer.weight_scale.data, requires_grad=False
            )
        else:
            if self.rocm_use_aiter_fp4_asm_gemm:
            if self.dynamic_mxfp4_quant:
                w_q, w_s = dynamic_mxfp4_quant(layer.weight)
                layer.weight_scale = torch.nn.Parameter(
                    w_s.T.contiguous(), requires_grad=False
                )
                layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
            elif self.rocm_use_aiter_fp4_asm_gemm:
                # shuffle weight scale
                weight_scale_shuffle = layer.weight_scale.data
                sm, sn = weight_scale_shuffle.shape
@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme):
        weight_loader: Callable,
        **kwargs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        if self.dynamic_mxfp4_quant:
            weight = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )

            # WEIGHT
            weight = PackedvLLMParameter(
                data=torch.empty(
                    output_size_per_partition,
                    self.get_packed_dim(input_size_per_partition, self.weight_dtype),
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                packed_dim=1,
                packed_factor=self.packed_factor,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight", weight)
            layer.register_parameter("weight", weight)
            set_weight_attrs(weight, kwargs)
        else:
            output_size_per_partition = sum(output_partition_sizes)
            layer.logical_widths = output_partition_sizes

            # WEIGHT SCALE
            weight_scale = GroupQuantScaleParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition // OCP_MX_BLOCK_SIZE,
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight_scale", weight_scale)
            # WEIGHT
            weight = PackedvLLMParameter(
                data=torch.empty(
                    output_size_per_partition,
                    self.get_packed_dim(input_size_per_partition, self.weight_dtype),
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                packed_dim=1,
                packed_factor=self.packed_factor,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight", weight)

            # WEIGHT SCALE
            weight_scale = GroupQuantScaleParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition // OCP_MX_BLOCK_SIZE,
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(
        self,

@@ -6,28 +6,18 @@ from typing import TYPE_CHECKING

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    activation_to_flashinfer_int,
    align_fp4_moe_weights_for_fi,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
    swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kNvfp4Dynamic,
    kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
    has_flashinfer_cutlass_fused_moe,
)

if TYPE_CHECKING:
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE
@@ -42,92 +32,15 @@ __all__ = [
    "reorder_w1w3_to_w3w1",
]

#
# Methods used by the oracle for kernel selection.
#


def _supports_current_device() -> bool:
    """Supports only Blackwell-family GPUs."""
    p = current_platform
    return p.is_cuda() and p.is_device_capability_family(100)


def _supports_no_act_and_mul() -> bool:
    """Supports non-gated MoE."""
    return True


def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Supports Nvfp4 quantization."""
    SUPPORTED_W_A = [
        (kNvfp4Static, kNvfp4Dynamic),
    ]
    return (weight_key, activation_key) in SUPPORTED_W_A


def _supports_activation(activation: MoEActivation) -> bool:
    return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]


def _supports_routing_method(
    routing_method: RoutingMethodType,
) -> bool:
    """Monolithic kernels need to express router support."""
    # NOTE(rob): potentially allow others here. This is a conservative list.
    return routing_method in [
        RoutingMethodType.DeepSeekV3,
        RoutingMethodType.Renormalize,
        RoutingMethodType.RenormalizeNaive,
        RoutingMethodType.Llama4,
    ]


def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """
    TRTLLM is a monolithic kernel that requires dispatch_router_logits() for
    the naive dispatch/combine path. DeepEP HT only implements dispatch() for
    the modular kernel path, so TRTLLM is incompatible with DeepEP HT.
    """
    return not moe_parallel_config.use_deepep_ht_kernels


def is_supported_config_trtllm(
    moe_config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
    activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
    """
    This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
    """

    def _make_reason(reason: str) -> str:
        return f"kernel does not support {reason}"

    if not _supports_current_device():
        return False, _make_reason(f"current device {current_platform.device_name}")
    elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
        return False, _make_reason("no act_and_mul MLP layer")
    elif not _supports_activation(moe_config.activation):
        return False, _make_reason(f"{moe_config.activation} activation")
    elif not _supports_quant_scheme(weight_key, activation_key):
        return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
    elif not _supports_parallel_config(moe_config.moe_parallel_config):
        return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
    elif not _supports_routing_method(moe_config.routing_method):
        return False, _make_reason(f"routing method {moe_config.routing_method}")
    elif activation_format != mk.FusedMoEActivationFormat.Standard:
        return False, _make_reason(f"activation format {activation_format}")
    elif moe_config.hidden_dim % 512 != 0:
        return False, _make_reason(
            f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}"
        )

    return True, None

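A hypothetical caller (not part of this diff, and assuming a module-level logger) would consume the (supported, reason) pair like this:

supported, reason = is_supported_config_trtllm(
    moe_config, weight_key, activation_key, activation_format
)
if not supported:
    logger.debug("TRTLLM FP4 MoE kernel rejected: %s", reason)
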
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
    """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
    return (
        envs.VLLM_USE_FLASHINFER_MOE_FP4
        and has_flashinfer_cutlass_fused_moe()
        and current_platform.is_cuda()
        and current_platform.has_device_capability(100)
    )


def reorder_w1w3_to_w3w1(
@@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe(
    )


def flashinfer_trtllm_fp4_moe(
    layer: torch.nn.Module,
    x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    router_logits: torch.Tensor,
    top_k: int,
    activation: MoEActivation,
    global_num_experts: int,
    num_expert_group: int | None,
    topk_group: int | None,
    custom_routing_function: object | None,
    e_score_correction_bias: torch.Tensor | None,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor
        router_logits: Router logits for expert selection
        top_k: Number of experts to select per token
        activation: Activation function to use
        global_num_experts: Total number of experts across all ranks
        num_expert_group: Number of expert groups (for grouped routing)
        topk_group: Top-k within each group
        custom_routing_function: Custom routing function (e.g., Llama4)
        e_score_correction_bias: Optional routing bias correction

    Returns:
        Output tensor from the MoE layer
    """
    import flashinfer

    from vllm.model_executor.models.llama4 import Llama4MoE

    SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
    assert activation in SUPPORTED_ACTIVATIONS, (
        f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
        f"TRTLLM FP4 MoE, {activation} found instead."
    )

    if isinstance(x, tuple):
        # hidden_states is already quantized
        hidden_states_fp4, hidden_states_scale_linear_fp4 = x
    else:
        # Quantize input to FP4
        (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
            x, layer.a1_gscale, is_sf_swizzled_layout=False
        )

    # Determine routing method type
    use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
    routing_method_type = layer.routing_method_type
    if use_llama4_routing:
        routing_method_type = flashinfer.RoutingMethodType.Llama4

    # Cast to Fp32 (required by kernel).
    router_logits = (
        router_logits.to(torch.float32)
        if routing_method_type == RoutingMethodType.DeepSeekV3
        else router_logits
    )

    # Determine activation type
    activation_type = activation_to_flashinfer_int(layer.activation)

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
        routing_logits=router_logits,
        routing_bias=e_score_correction_bias,
        hidden_states=hidden_states_fp4,
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).reshape(*hidden_states_fp4.shape[:-1], -1),
        gemm1_weights=layer.w13_weight.data,
        gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.w2_weight.data,
        gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group if num_expert_group is not None else 0,
        topk_group=topk_group if topk_group is not None else 0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        routing_method_type=routing_method_type,
        do_finalize=True,
        activation_type=activation_type,
    )[0]

    return out


def flashinfer_trtllm_fp4_routed_moe(
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    top_k: int,
    activation: MoEActivation,
    global_num_experts: int,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
    input top-k expert indices and scores rather than computing
    top-k expert indices from scores.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor
        topk_ids: Ids of selected experts
        topk_weights: Routing weights of the selected experts
        top_k: Number of experts to select per token
        activation: Activation function to use
        global_num_experts: Total number of experts across all ranks

    Returns:
        Output tensor from the MoE layer
    """
    import flashinfer

    # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
    assert activation == MoEActivation.SILU, (
        "Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
        f"{activation} found instead."
    )

    # Pack top-k ids and expert weights into a single int32 tensor, as
    # required by TRT-LLM
    packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
        torch.bfloat16
    ).view(torch.int16)

    if isinstance(x, tuple):
        # hidden_states is already quantized
        hidden_states_fp4, hidden_states_scale_linear_fp4 = x
    else:
        # Quantize input to FP4
        (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
            x, layer.a1_gscale, is_sf_swizzled_layout=False
        )

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
        topk_ids=packed_tensor,
        routing_bias=None,
        hidden_states=hidden_states_fp4,
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).reshape(*hidden_states_fp4.shape[:-1], -1),
        gemm1_weights=layer.w13_weight.data,
        gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.w2_weight.data,
        gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=0,
        topk_group=0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        routing_method_type=1,
        do_finalize=True,
    )[0]

    return out


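The packed routing layout built above (expert id in the high 16 bits, the bfloat16 bit pattern of the routing weight in the low 16 bits) can be checked with a small standalone sketch; the bf16_bits helper below is hypothetical and only illustrates the bit layout, it is not part of the diff:

import struct

def bf16_bits(x: float) -> int:
    # bfloat16 keeps the top 16 bits of the float32 representation
    return struct.unpack(">I", struct.pack(">f", x))[0] >> 16

expert_id, weight = 7, 0.625
packed = (expert_id << 16) | bf16_bits(weight)
assert packed >> 16 == expert_id              # high half: expert index
assert packed & 0xFFFF == bf16_bits(weight)   # low half: bf16 routing weight
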
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
    backend: "NvFp4MoeBackend",
    layer: "FusedMoE",
@@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
        )
    )
    layer.intermediate_size_per_partition = padded_intermediate
    layer.moe_config.intermediate_size_per_partition = padded_intermediate

    w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
        w13,

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import TYPE_CHECKING

import torch

@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up

if TYPE_CHECKING:
    from flashinfer.fused_moe.core import ActivationType

logger = init_logger(__name__)


@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum):


def activation_to_flashinfer_int(activation: MoEActivation) -> int:
    return activation_to_flashinfer_type(activation).value


def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
    from flashinfer.fused_moe.core import ActivationType

    # silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int:
        MoEActivation.GELU: ActivationType.Geglu,
        MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
    }
    return ACTIVATION_TO_FI_ACTIVATION[activation].value
    return ACTIVATION_TO_FI_ACTIVATION[activation]


def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
    )


def register_scales_for_trtllm_fp8_per_tensor_moe(
    layer: torch.nn.Module,
    w13_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> None:
    """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
    g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
        w13_scale=w13_scale,
        w13_input_scale=w13_input_scale,
        w2_scale=w2_scale,
        w2_input_scale=w2_input_scale,
    )
    layer.w2_input_scale_inv = 1.0 / w2_input_scale
    layer.output1_scales_gate_scalar = g1_alphas

    if layer.activation.is_gated:
        layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
    else:
        layer.output1_scales_scalar = (
            torch.ones_like(g1_alphas) * layer.w2_input_scale_inv
        )
    layer.output2_scales_scalar = g2_alphas


def apply_fi_trtllm_fp8_per_tensor_moe(
    layer: torch.nn.Module,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    routing_bias: torch.Tensor | None,
    top_k: int,
    num_expert_group: int | None,
    topk_group: int | None,
    global_num_experts: int,
    apply_router_weight_on_input: bool,
) -> torch.Tensor:
    from flashinfer.fused_moe import RoutingMethodType

    import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
    from vllm.model_executor.models.llama4 import Llama4MoE

    # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
    assert (
        hasattr(layer, "output1_scales_scalar")
        and hasattr(layer, "output1_scales_gate_scalar")
        and hasattr(layer, "output2_scales_scalar")
    )

    if layer.routing_method_type == RoutingMethodType.Llama4:
        assert (
            not layer.renormalize
            and layer.custom_routing_function == Llama4MoE.custom_routing_function
        ), (
            "FusedMoE flashinfer kernels with Llama4 routing method are only "
            "supported for Llama4"
        )
    else:
        assert layer.custom_routing_function is None, (
            "Custom routing function is only supported for Llama4"
        )
    activation_type = activation_to_flashinfer_int(layer.activation)

    return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
        routing_logits=router_logits,
        routing_bias=routing_bias,
        hidden_states=hidden_states,
        input_scale=layer.w13_input_scale,
        gemm1_weights=layer.w13_weight,
        gemm2_weights=layer.w2_weight,
        output1_scales_scalar=layer.output1_scales_scalar,
        output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
        output2_scales_scalar=layer.output2_scales_scalar,
        num_experts=global_num_experts,
        top_k=top_k,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        use_routing_scales_on_input=apply_router_weight_on_input,
        routing_method_type=layer.routing_method_type,
        activation_type=activation_type,
    )


def make_fp8_moe_alpha_scales_for_fi(
    w13_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    g1_alphas = (w13_scale * w13_input_scale).squeeze()
    g2_alphas = (w2_scale * w2_input_scale).squeeze()

    return g1_alphas, g2_alphas


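The alpha scales are plain elementwise products of the weight and activation scales; a tiny check with made-up values (not from the diff):

import torch

g1, g2 = make_fp8_moe_alpha_scales_for_fi(
    w13_scale=torch.tensor([[0.02]]),
    w13_input_scale=torch.tensor([0.5]),
    w2_scale=torch.tensor([[0.04]]),
    w2_input_scale=torch.tensor([0.25]),
)
# g1 == 0.01 and g2 == 0.01, squeezed to scalars
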
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
    backend_map = {
        "throughput": FlashinferMoeBackend.CUTLASS,
@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi(
        min_alignment,
    )
    layer.intermediate_size_per_partition = new_intermediate
    layer.moe_config.intermediate_size_per_partition = new_intermediate

    # FI kernels require W31 layout rather than W13.
    if layer.moe_config.is_act_and_mul:
@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi(
        w13_scale = swap_w13_to_w31(w13_scale)

    # FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
    # and registration of alpha scales. Note that we do not register
    # as nn.Parameters since they are not needed for weight-reloading.
    # and registration of alpha scales.
    if is_trtllm and not block_quant:
        assert w13_input_scale is not None
        assert w2_input_scale is not None

        rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
        register_scales_for_trtllm_fp8_per_tensor_moe(
            layer,
            w13_scale=w13_scale,
            w13_input_scale=w13_input_scale,
            w2_scale=w2_scale,
            w2_input_scale=w2_input_scale,
        )

    # Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
    # Some FP8 models have near-zero block scales (~1e-23) for dead/unused

@@ -53,7 +53,10 @@ logger = init_logger(__name__)
def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
    if isinstance(x, torch.Tensor):
        x = x.dtype
    return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
    try:
        return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
    except:
        return False


# We need to pass in the is_hopper flag as argument because the function

vllm/model_executor/layers/quantization/utils/gguf_utils.py (new file, 373 lines)
@@ -0,0 +1,373 @@
import torch
import numpy as np
from gguf.constants import GGMLQuantizationType

def get_awq_format(w, group_size=128, w_bit=4):
    org_w_shape = w.shape
    ori_w_dtype = torch.get_default_dtype()
    assert w_bit == 4
    assert w.shape[1] % group_size == 0

    in_features = org_w_shape[1]
    w = w.reshape(-1, group_size)
    assert torch.isnan(w).sum() == 0

    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int = 2**w_bit - 1
    min_int = 0
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    w = (
        torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
    ) * scales
    zeros = zeros.view(org_w_shape[0], -1)
    scales = scales.view(org_w_shape[0], -1)
    w = w.reshape(org_w_shape)
    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    scales = scales.t().contiguous()  # input // group, o
    zeros = zeros.t().contiguous()  # input // group, o

    # from auto awq
    scale_zeros = zeros * scales
    scales = scales.clone().to(ori_w_dtype)

    pack_num = 32 // w_bit
    intweight = []
    for idx in range(in_features):
        intweight.append(
            torch.round(
                (w[:, idx] + scale_zeros[idx // group_size])
                / scales[idx // group_size]
            ).to(torch.int)[:, None]
        )
    intweight = torch.cat(intweight, dim=1)
    intweight = intweight.t().contiguous()
    intweight = intweight.to(dtype=torch.int32)

    qweight = torch.zeros(
        (intweight.shape[0], intweight.shape[1] // 32 * w_bit),
        dtype=torch.int32,
        device=intweight.device,
    )

    for col in range(intweight.shape[1] // pack_num):
        order_map = [0, 2, w_bit, 6, 1, 3, 5, 7]
        for i in range(pack_num):
            qweight_col = intweight[:, col * pack_num + order_map[i]]
            qweight[:, col] |= qweight_col << (i * w_bit)

    zeros = zeros.to(dtype=torch.int32, device=qweight.device)

    qzeros = torch.zeros(
        (zeros.shape[0], zeros.shape[1] // 32 * w_bit),
        dtype=torch.int32,
        device=zeros.device,
    )

    for col in range(zeros.shape[1] // pack_num):
        order_map = [0, 2, w_bit, 6, 1, 3, 5, 7]
        for i in range(pack_num):
            qzero_col = zeros[:, col * pack_num + order_map[i]]
            qzeros[:, col] |= qzero_col << (i * w_bit)

    return qweight, qzeros, scales

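As a rough usage sketch (hypothetical shapes, not part of the new file): calling get_awq_format on an [out_features, in_features] float weight yields AWQ-style packed buffers with the shapes noted below.

import torch

w = torch.randn(256, 512)  # [out_features, in_features]; in_features % group_size == 0
qweight, qzeros, scales = get_awq_format(w, group_size=128, w_bit=4)
# qweight: [in_features, out_features // 8]                int32, eight 4-bit values per word
# qzeros:  [in_features // group_size, out_features // 8]  int32
# scales:  [in_features // group_size, out_features]
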
GGML_BLOCK_SIZES = {
    "F32": 4,
    "F16": 2,
    "Q4_0": 2 + 16,
    "Q5_0": 2 + 4 + 16,
    "Q8_0": 2 + 32,
    "Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
    "Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
    "Q4_K": 2 + 2 + 12 + 256 // 2,
    "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
    "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
    "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
}

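A quick sanity check of the table (assuming the *_0 formats encode 32 weights per block and the K/IQ formats encode 256):

assert GGML_BLOCK_SIZES["Q4_0"] / 32 == 0.5625   # 4.5 bits per weight
assert GGML_BLOCK_SIZES["Q8_0"] / 32 == 1.0625   # 8.5 bits per weight
assert GGML_BLOCK_SIZES["Q4_K"] / 256 == 0.5625  # 4.5 bits per weight
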
def dequantize_f32(data):
    return np.frombuffer(data, dtype=np.float32)

def dequantize_f16(data):
    return np.frombuffer(data, dtype=np.float16)

def dequantize_q4_0(data):
    num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]

    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)
    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]

    return np.concatenate([
        scales * ((qs & 0xf).astype(np.int8) - 8),
        scales * ((qs >> 4).astype(np.int8) - 8),
    ], axis=1)

def dequantize_q5_0(data):
    num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]

    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
    qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]

    bits = np.unpackbits(qh, axis=-1, bitorder="little")

    x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
    x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16

    return np.concatenate([
        scales * x0,
        scales * x1,
    ], axis=1)

def dequantize_q8_0(data):
    num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]

    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)
    qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
    return scales * qs

def dequantize_q2_k(data):
    block_size = GGML_BLOCK_SIZES["Q2_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

    dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
    d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
    scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
    qs = data_u8[:, 16:80].reshape(num_blocks, 64)

    tmp = np.stack([
        qs[:, 00:16] >> 0,
        qs[:, 16:32] >> 0,
        qs[:, 00:16] >> 2,
        qs[:, 16:32] >> 2,
        qs[:, 00:16] >> 4,
        qs[:, 16:32] >> 4,
        qs[:, 00:16] >> 6,
        qs[:, 16:32] >> 6,
        qs[:, 32:48] >> 0,
        qs[:, 48:64] >> 0,
        qs[:, 32:48] >> 2,
        qs[:, 48:64] >> 2,
        qs[:, 32:48] >> 4,
        qs[:, 48:64] >> 4,
        qs[:, 32:48] >> 6,
        qs[:, 48:64] >> 6,
    ], axis=1)

    return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)


def dequantize_q3_k(data):
    block_size = GGML_BLOCK_SIZES["Q3_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

    d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
    bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
    bits = 4 ^ (bits << 2)
    qs = data_u8[:, 32:32 + 64].astype(np.int16)
    a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
    scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
    scales[:, 0] = (a & 15) | ((c & 3) << 4)
    scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
    scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
    scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
    scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)

    return d * (scales - 32) * np.stack([
        (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
        (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
        (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
        (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
        (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
        (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
        (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
        (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
        (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
        (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
        (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
        (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
        (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
        (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
        (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
        (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
    ], axis=1)

def dequantize_q4_k(data, device=None):
    block_size = GGML_BLOCK_SIZES["Q4_K"]
    num_blocks = len(data) // block_size
    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
    # Casting to float32 because float16 is very slow on CPU
    scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
    scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
    qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
    qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
    # Dequantize scales and offsets (6 bits and 4 + 2 bits)
    factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
    offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)
    # Interleave low and high quantized bits
    qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
    # Dequantize final weights using scales and offsets
    weight = factors * qs2 - offsets
    if device is None:
        return weight
    return torch.from_numpy(weight).to(device=device)

def dequantize_q5_k(data):
    block_size = GGML_BLOCK_SIZES["Q5_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

    d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
    dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
    scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
    qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1)
    qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32)

    bits = np.unpackbits(qh, axis=-1, bitorder="little")

    qs_hi_4 = qs >> 4
    qs_lo_4 = qs & 15

    scales_lo_6 = scales[:, :8] & 63
    scales_hi_6 = scales[:, :8] >> 6
    scales_lo_4 = scales[:, 8:] & 15
    scales_hi_4 = scales[:, 8:] >> 4

    m1 = dmin * scales_lo_6[:, 4]
    m2 = dmin * scales_lo_6[:, 5]
    m3 = dmin * scales_lo_6[:, 6]
    m4 = dmin * scales_lo_6[:, 7]
    m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
    m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
    m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
    m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))

    d1 = d * scales_lo_6[:, 0]
    d2 = d * scales_lo_6[:, 1]
    d3 = d * scales_lo_6[:, 2]
    d4 = d * scales_lo_6[:, 3]
    d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
    d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
    d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
    d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))

    return np.concatenate([
        d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
        d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
        d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
        d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
        d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
        d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
        d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
        d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
    ], axis=1)

def dequantize_q6_k(data, device=None):
    block_size = GGML_BLOCK_SIZES["Q6_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
    data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)

    scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
    # TODO use uint8 and cast later?
    ql = data_u8[:, :128].astype(np.int16)
    qh = data_u8[:, 128:192].astype(np.int16)
    sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)

    # Unpack bits, subtraction requires signed data type
    q1 = (ql[:, :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32
    q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32
    q3 = (ql[:, :32 ] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32
    q4 = (ql[:, 32:64 ] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32
    q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32
    q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32
    q7 = (ql[:, 64:96 ] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32
    q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32

    # Dequantize
    weight = scales * np.concatenate([
        sc[:, 0] * q1[:, :16],
        sc[:, 1] * q1[:, 16:],
        sc[:, 2] * q2[:, :16],
        sc[:, 3] * q2[:, 16:],
        sc[:, 4] * q3[:, :16],
        sc[:, 5] * q3[:, 16:],
        sc[:, 6] * q4[:, :16],
        sc[:, 7] * q4[:, 16:],
        sc[:, 8] * q5[:, :16],
        sc[:, 9] * q5[:, 16:],
        sc[:, 10] * q6[:, :16],
        sc[:, 11] * q6[:, 16:],
        sc[:, 12] * q7[:, :16],
        sc[:, 13] * q7[:, 16:],
        sc[:, 14] * q8[:, :16],
        sc[:, 15] * q8[:, 16:],
    ], axis=1)

    if device is None:
        return weight
    return torch.from_numpy(weight).to(device=device)

QK_K = 256
kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)

def dequantize_iq4_xs(data):
    block_size = GGML_BLOCK_SIZES["IQ4_XS"]
    num_blocks = len(data) // block_size

    d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)
    scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]
    scales_l = data_u8[:, :4].reshape(num_blocks, 4)
    qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)

    ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)
    for ib in range(QK_K // 32):
        ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)

    dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)

    qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf
    qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4

    y = np.zeros((num_blocks, QK_K), dtype=np.float32)
    for ib in range(QK_K // 32):
        y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]
        y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]

    return y.flatten()

GGML_DEQUANTIZE = {
    int(GGMLQuantizationType.F32): dequantize_f32,
    int(GGMLQuantizationType.F16): dequantize_f16,
    int(GGMLQuantizationType.Q4_0): dequantize_q4_0,
    int(GGMLQuantizationType.Q5_0): dequantize_q5_0,
    int(GGMLQuantizationType.Q8_0): dequantize_q8_0,
    int(GGMLQuantizationType.Q2_K): dequantize_q2_k,
    int(GGMLQuantizationType.Q3_K): dequantize_q3_k,
    int(GGMLQuantizationType.Q4_K): dequantize_q4_k,
    int(GGMLQuantizationType.Q5_K): dequantize_q5_k,
    int(GGMLQuantizationType.Q6_K): dequantize_q6_k,
    int(GGMLQuantizationType.IQ4_XS): dequantize_iq4_xs,
}


def dequant_gguf(data, type, shape):
    values = GGML_DEQUANTIZE[type](data)
    values = torch.from_numpy(values).view(shape)
    return values
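A minimal, hand-rolled example of exercising dequant_gguf; the block bytes are synthetic and only assume the Q8_0 layout used by dequantize_q8_0 above (one float16 scale followed by 32 int8 quants):

import numpy as np

scale = np.float16(0.5).tobytes()
quants = np.arange(-16, 16, dtype=np.int8).tobytes()
block = scale + quants  # one 34-byte Q8_0 block

out = dequant_gguf(block, int(GGMLQuantizationType.Q8_0), (1, 32))
# out[0, 0] == -8.0 (0.5 * -16), dtype float32
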
@@ -255,18 +255,6 @@ def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tenso
    return w2_packed.size(1) * marlin_tile_size


def marlin_make_workspace(
    output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
    max_workspace_size = (
        output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    ) * GPTQ_MARLIN_MAX_PARALLEL

    return torch.zeros(
        max_workspace_size, dtype=torch.int, device=device, requires_grad=False
    )


def marlin_make_workspace_new(
    device: torch.device, max_blocks_per_sm: int = 1
) -> torch.Tensor:
@@ -297,12 +285,6 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    )


def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(
        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
    )


def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[g_idx_sort_indices], g_idx_sort_indices

@@ -175,7 +175,7 @@ try:
        op_func=_dequant_mxfp4,
        fake_impl=_dequant_mxfp4_fake,
    )
    dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
    dequant_mxfp4 = None
except AttributeError as error:
    raise error

@@ -185,6 +185,6 @@ try:
        op_func=_quant_dequant_mxfp4,
        fake_impl=_quant_dequant_mxfp4_fake,
    )
    quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
    quant_dequant_mxfp4 = None
except AttributeError as error:
    raise error

@@ -271,12 +271,12 @@ def scaled_quantize(
        If None, uses input dtype. Use torch.float32 for higher precision.
    """
    group_shape = _normalize_quant_group_shape(x, group_shape)
    assert quant_dtype.is_floating_point, (
        "currently `scaled_quantize` only supports floating point dtypes "
        "but could be extended to support other dtypes"
    )
    # assert quant_dtype.is_floating_point, (
    #     "currently `scaled_quantize` only supports floating point dtypes "
    #     "but could be extended to support other dtypes"
    # )

    finfo = torch.finfo(quant_dtype)
    finfo = (
        torch.finfo(quant_dtype)
        if quant_dtype.is_floating_point
        else torch.iinfo(quant_dtype)
    )

    # Convert to compute dtype if specified
    x_compute = x if compute_dtype is None else x.to(compute_dtype)

vllm/model_executor/layers/quantization/w8a16.py (new file, 114 lines)
@@ -0,0 +1,114 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs


class W8a16Config(QuantizationConfig):
    """Config class for W8a16.

    """

    def __init__(
        self,
    ) -> None:
        pass

    def __repr__(self) -> str:
        return ("W8a16Config")

    def get_name(self) -> str:
        return "w8a16"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    def get_min_capability(self) -> int:
        return 75

    @staticmethod
    def get_config_filenames():
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8a16Config":
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["W8a16LinearMethod"]:
        if isinstance(layer, LinearBase):
            return W8a16LinearMethod(self)
        return None


    def get_scaled_act_names(self) -> List[str]:
        return []


class W8a16LinearMethod(LinearMethodBase):
    """Linear method for w8a16.

    """

    def __init__(self, quant_config: W8a16Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            weight, {
                "input_dim": 1,
                "output_dim": 0,
            })

        scales = Parameter(
            torch.empty(
                1,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": None,
            "output_dim": 1,
        })

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)


    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.weight
        scales = layer.scales
        out_shape = (x.shape[:-1] + (qweight.shape[-2],))
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = ops.linear_w8a16(reshaped_x, qweight, scales, format="TN")
        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)
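For reference, the math this w8a16 path is expected to perform can be written in plain PyTorch as below; this sketch assumes per-output-channel scales and an int8 weight in [out, in] layout, which is an assumption about ops.linear_w8a16 rather than its actual implementation:

import torch

x = torch.randn(4, 64, dtype=torch.half)                          # activations
qweight = torch.randint(-128, 127, (128, 64), dtype=torch.int8)   # [out, in] int8 weight
scales = torch.rand(1, 128, dtype=torch.half)                     # per-output-channel scales

w_dequant = qweight.float() * scales.float().t()                  # dequantized weight, [128, 64]
ref = (x.float() @ w_dequant.t()).to(x.dtype)                     # [4, 128], layer output before bias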