This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

View File

@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warmup deep_gemm kernels.
DeepGEMM JIT's the kernels. The warmup aims to JIT all the kernels that would
be used during model execution beforehand.
"""
import torch
from tqdm import tqdm
import vllm.envs as envs
from vllm.distributed.parallel_state import get_dp_group
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M, deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
def _extract_data_from_linear_base_module(
        m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """
    Pull the weight, the weight scales and the quantization block sizes out
    of a LinearBase module that uses block-quantized Fp8.

    Raises AssertionError if the module is not a block-quantized
    Fp8 LinearBase layer or if any expected attribute is missing.
    """
    assert isinstance(m, LinearBase)
    assert isinstance(m.quant_method, Fp8LinearMethod)
    assert m.quant_method.block_quant
    assert m.quant_method.quant_config is not None

    weight = m.weight
    weight_scale = m.weight_scale
    block_sizes = m.quant_method.quant_config.weight_block_size

    assert isinstance(weight, torch.Tensor)
    assert isinstance(weight_scale, torch.Tensor)
    assert block_sizes is not None
    return (weight, weight_scale, block_sizes)
def _extract_data_from_fused_moe_module(
    m: torch.nn.Module
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """
    Pull the w13 / w2 weights, their scales and num_topk out of a FusedMoE
    module.

    Scale tensors may be stored under either the "*_weight_scale_inv" or
    the plain "*_weight_scale" attribute depending on checkpoint format;
    the inverse-scale name is preferred when present.
    """
    assert isinstance(m, FusedMoE)

    w13 = m.w13_weight
    if hasattr(m, "w13_weight_scale_inv"):
        w13_s = m.w13_weight_scale_inv
    else:
        w13_s = m.w13_weight_scale

    w2 = m.w2_weight
    if hasattr(m, "w2_weight_scale_inv"):
        w2_s = m.w2_weight_scale_inv
    else:
        w2_s = m.w2_weight_scale

    num_topk = m.top_k

    for t in (w13, w13_s, w2, w2_s):
        assert isinstance(t, torch.Tensor)
    return w13, w13_s, w2, w2_s, num_topk
def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
    """
    Return True if the input module/layer could be processed with DeepGEMM.

    Requires a block-quantized Fp8 LinearBase layer whose quantization block
    sizes match DeepGEMM's block shape and whose 2D weight dims are both
    multiples of the block size.
    """
    dg_block_shape = deep_gemm_block_shape()
    block_size = dg_block_shape[0]

    is_block_fp8_linear = (isinstance(module, LinearBase)
                           and isinstance(module.quant_method, Fp8LinearMethod)
                           and module.quant_method.block_quant)
    if not is_block_fp8_linear:
        return False

    w, _, block_sizes = _extract_data_from_linear_base_module(module)
    if block_sizes != dg_block_shape or w.ndim != 2:
        return False
    return w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0
def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
    """
    Return True if this FusedMoE module's grouped GEMMs may be dispatched
    to DeepGEMM kernels.
    """
    if not isinstance(module, FusedMoE):
        return False

    quant_config = module.quant_method.get_fused_moe_quant_config(module)
    if quant_config is None:
        return False
    if quant_config.quant_dtype != torch.float8_e4m3fn:
        return False
    if quant_config.block_shape != deep_gemm_block_shape():
        return False

    fused_experts = module.quant_method.fused_experts
    if not isinstance(fused_experts, FusedMoEModularKernel):
        # fused_experts could invoke deep_gemm_moe_fp8
        return True

    # Further check if the ModularKernel implementation uses the DeepGemmExperts
    return isinstance(fused_experts.fused_experts,
                      (DeepGemmExperts, TritonOrDeepGemmExperts))
# Weight sizes already warmed up for fp8_gemm_nt — layers that share the same
# (n, k) weight shape are only JIT-warmed once per process.
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
                                 max_tokens: int):
    """
    Run fp8_gemm_nt once for every token count from max_tokens down to 1 so
    DeepGEMM JIT-compiles all kernel variants for this weight shape.
    Weight sizes already warmed up are skipped via FP8_GEMM_NT_WARMUP_CACHE.
    """
    if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
        return

    n, k = w.size()
    block_m = deep_gemm_block_shape()[0]
    device = w.device

    # Dummy quantized activations, scales and output sized for the largest M;
    # smaller Ms below reuse slices of the same buffers.
    a1q = torch.empty((max_tokens, k),
                      device=device,
                      dtype=torch.float8_e4m3fn)
    a1q_scales = torch.empty((max_tokens, k // block_m),
                             device=device,
                             dtype=torch.float32)
    out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)

    pbar = tqdm(total=max_tokens,
                desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
    for num_tokens in range(max_tokens, 0, -1):
        fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws),
                    out[:num_tokens])
        pbar.update(1)

    FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
# Weight sizes already warmed up for the grouped contiguous fp8 GEMM — MoE
# layers sharing the same expert weight shape are only JIT-warmed once.
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
        w1: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor,
        w2_scale: torch.Tensor, num_topk: int, max_tokens: int):
    """
    JIT the grouped contiguous fp8 GEMM kernels for both MoE weight matrices
    (w1 and w2), for every aligned M from MAX_M down to block_m.

    Args:
        w1 / w2: expert weight tensors, shape (num_experts, n, k).
        w1_scale / w2_scale: corresponding per-block weight scales.
        num_topk: experts selected per token (FusedMoE top_k).
        max_tokens: per-rank max_num_batched_tokens used to size the warmup.
    """
    # Skip only when BOTH shapes are cached; the per-weight check in the loop
    # below still avoids redundant work when just one of them is cached.
    if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
            and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
        return
    assert w1.size(0) == w2.size(0), (
        "w1 and w2 must have the same number of experts")
    block_m = deep_gemm_block_shape()[0]
    num_experts = w1.size(0)
    device = w1.device
    # Assumes all ranks have the same max_num_batched_tokens
    max_tokens_across_dp = get_dp_group().world_size * max_tokens
    # MoE execution chunks the batch, so the GEMM never sees more than
    # VLLM_FUSED_MOE_CHUNK_SIZE tokens at once.
    max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
    # This is the maximum GroupedGemm M size that we expect to run
    # the grouped_gemm with.
    MAX_M = compute_aligned_M(max_tokens,
                              num_topk,
                              num_experts,
                              block_m,
                              expert_tokens_meta=None)
    # Distribute expert-ids evenly.
    MAX_BLOCKS = MAX_M // block_m
    # One random expert id per block_m-sized row block; repeated so every row
    # in a block maps to the same expert, as the contiguous layout requires.
    expert_ids_block = torch.randint(low=0,
                                     high=num_experts,
                                     size=(MAX_BLOCKS, ),
                                     device=device,
                                     dtype=torch.int32)
    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)

    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
        # Run the grouped GEMM once for each aligned M (MAX_M, MAX_M - block_m,
        # ..., block_m) so every kernel variant gets JIT'ed.
        _, n, k = w.size()
        a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
        a1q_scales = torch.empty((MAX_M, k // block_m),
                                 device=device,
                                 dtype=torch.float32)
        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
        pbar = tqdm(
            total=MAX_BLOCKS,
            desc=
            f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})"
        )
        num_tokens = MAX_M
        while num_tokens > 0:
            m_grouped_fp8_gemm_nt_contiguous(
                (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
                out[:num_tokens], expert_ids[:num_tokens])
            pbar.update(1)
            num_tokens = num_tokens - block_m

    for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
        if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
            _warmup(w, ws)
            GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())
def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
    """
    Warm up the dense fp8 DeepGEMM kernels for every eligible linear layer
    in the model.
    """
    for module in model.modules():
        if not _fp8_linear_may_use_deep_gemm(module):
            continue
        w, ws, _ = _extract_data_from_linear_base_module(module)
        _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)
def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module,
                                                   max_tokens: int):
    """
    Warm up the grouped fp8 DeepGEMM kernels for every FusedMoE layer in the
    model that may dispatch to DeepGEMM.
    """
    for module in model.modules():
        if not _fused_moe_grouped_gemm_may_use_deep_gemm(module):
            continue
        w13, w13_scale, w2, w2_scale, num_topk = (
            _extract_data_from_fused_moe_module(module))
        _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
            w13, w2, w13_scale, w2_scale, num_topk, max_tokens)
def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
    """
    Entry point: JIT all DeepGEMM kernels (dense fp8 GEMM and grouped MoE
    GEMM) the model may use, for token counts up to max_tokens.
    """
    deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)

View File

@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warmup kernels used during model execution.
This is useful specifically for JIT'ed kernels as we don't want JIT'ing to
happen during model execution.
"""
from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.gpu_worker import Worker
# Module-level logger, per vLLM's standard per-module logging convention.
logger = init_logger(__name__)
def kernel_warmup(worker: "Worker"):
    """
    Warm up / pre-JIT kernels before serving so compilation and autotuning
    cost is not paid during model execution.

    Covers: DeepGEMM JIT warmup (when enabled and supported), FlashInfer
    autotuning on SM 9.0+ GPUs, and FlashInfer attention warmup via a
    mixed prefill/decode dummy run.
    """
    # Deep GEMM warmup
    do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM
                           and is_deep_gemm_supported()
                           and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP)
    if do_deep_gemm_warmup:
        model = worker.get_model()
        max_tokens = worker.scheduler_config.max_num_batched_tokens
        deep_gemm_warmup(model, max_tokens)
    # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
    # NOTE: has_device_capability(90) means capability >= 9.0.
    if has_flashinfer() and current_platform.has_device_capability(90):
        flashinfer_autotune(worker.model_runner)
    # FlashInfer attention warmup
    # Only warmup if the model has FlashInfer attention groups
    # and is not a pooling model
    def _is_flashinfer_backend(backend):
        # get_name() may be unimplemented for some backends; treat that as
        # "not FlashInfer" rather than failing the warmup.
        try:
            return backend.get_name() == "FLASHINFER"
        except NotImplementedError:
            return False
    # Warm up only when EVERY attention group uses FlashInfer.
    if not worker.model_runner.is_pooling_model and all(
            _is_flashinfer_backend(group.backend)
            for groups in worker.model_runner.attn_groups for group in groups):
        logger.info("Warming up FlashInfer attention.")
        # Warmup with mixed batch containing both prefill and decode tokens
        # This is to warm up both prefill and decode attention kernels
        worker.model_runner._dummy_run(
            num_tokens=16,
            skip_eplb=True,
            is_profile=True,
            force_attention=True,
            create_mixed_batch=True,
        )
def flashinfer_autotune(runner: "GPUModelRunner") -> None:
    """
    Autotune FlashInfer operations.

    FlashInfer have many implementations for the same operation,
    autotuning runs benchmarks for each implementation and stores
    the results. The results are cached transparently and
    future calls to FlashInfer will use the best implementation.
    Without autotuning, FlashInfer will rely on heuristics, which may
    be significantly slower.
    """
    from vllm.utils.flashinfer import autotune

    with torch.inference_mode(), autotune():
        # We skip EPLB here since we don't want to record dummy metrics
        # When autotuning with number of tokens m, flashinfer will autotune
        # operations for all number of tokens up to m.
        # So we only need to run with the max number of tokens.
        max_num_tokens = runner.scheduler_config.max_num_batched_tokens
        runner._dummy_run(max_num_tokens, skip_eplb=True, is_profile=True)