315 lines
11 KiB
Python
315 lines
11 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
Warm up DeepGEMM kernels.

DeepGEMM JIT-compiles its kernels. The warmup aims to JIT-compile, ahead of
time, all the kernels that would be used during model execution.
|
|
"""
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
import vllm.envs as envs
|
|
from vllm.distributed.parallel_state import get_dp_group
|
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
|
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
|
|
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
|
|
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
|
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
|
TritonOrDeepGemmExperts,
|
|
)
|
|
from vllm.model_executor.layers.linear import LinearBase
|
|
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
|
from vllm.utils.deep_gemm import (
|
|
fp8_gemm_nt,
|
|
get_mk_alignment_for_contiguous_layout,
|
|
m_grouped_fp8_gemm_nt_contiguous,
|
|
)
|
|
|
|
|
|
def _generate_optimal_warmup_m_values(
|
|
max_tokens: int, n: int, device: torch.device
|
|
) -> list[int]:
|
|
"""
|
|
Generate M values that cover all possible DeepGEMM kernel configurations.
|
|
Reference: https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp
|
|
|
|
Args:
|
|
max_tokens: Maximum number of tokens to warmup for
|
|
n: The actual N dimension from the weight tensor
|
|
device: The torch device to get properties from.
|
|
"""
|
|
|
|
def ceil_div(a: int, b: int) -> int:
|
|
return (a + b - 1) // b
|
|
|
|
# DeepGEMM's possible block sizes
|
|
block_ms = [64, 128, 256]
|
|
block_ns = list(range(16, min(257, n + 1), 16))
|
|
num_sms = torch.cuda.get_device_properties(device).multi_processor_count
|
|
|
|
m_values = set()
|
|
|
|
# Always include small cases
|
|
m_values.update([1, 2, 4] + [i for i in range(8, 65, 8)])
|
|
|
|
# Collect M values where different wave patterns occur
|
|
for block_m in block_ms:
|
|
for block_n in block_ns:
|
|
if block_n > n:
|
|
continue
|
|
|
|
# Add key M boundaries for this block combination
|
|
for wave in range(1, 11): # Up to 10 waves
|
|
# M where this block config transitions to next wave
|
|
target_blocks = wave * num_sms
|
|
m = target_blocks * block_m // ceil_div(n, block_n)
|
|
if 1 <= m <= max_tokens:
|
|
m_values.add(m)
|
|
|
|
# Add block_m boundaries
|
|
for multiple in range(1, max_tokens // block_m + 1):
|
|
m = multiple * block_m
|
|
if m <= max_tokens:
|
|
m_values.add(m)
|
|
|
|
return sorted(m_values)
|
|
|
|
|
|
def _extract_data_from_linear_base_module(
    m: torch.nn.Module,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """
    Pull the weight, the weight scale and the quantization block size out of
    a LinearBase module that uses block-quantized Fp8LinearMethod.

    The scale is read from `weight_scale_inv` when the module has that
    attribute, otherwise from `weight_scale`.

    Returns:
        Tuple of (weight, weight scale, weight block size).
    """
    assert isinstance(m, LinearBase)
    method = m.quant_method
    assert isinstance(method, Fp8LinearMethod)
    assert method.block_quant
    assert method.quant_config is not None

    weight = m.weight
    if hasattr(m, "weight_scale_inv"):
        scale = m.weight_scale_inv
    else:
        scale = m.weight_scale
    block_shape = method.quant_config.weight_block_size

    assert isinstance(weight, torch.Tensor)
    assert isinstance(scale, torch.Tensor)
    assert block_shape is not None
    return weight, scale, block_shape
|
|
|
|
|
|
def _extract_data_from_fused_moe_module(
    m: torch.nn.Module,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """
    Pull the expert weights, their scales and top_k out of a FusedMoE module.

    For each of w13 and w2 the scale is read from the `*_weight_scale_inv`
    attribute when the module has it, otherwise from `*_weight_scale`.

    Returns:
        Tuple of (w13 weight, w13 scale, w2 weight, w2 scale, top_k).
    """
    assert isinstance(m, FusedMoE)

    gate_up_weight = m.w13_weight
    if hasattr(m, "w13_weight_scale_inv"):
        gate_up_scale = m.w13_weight_scale_inv
    else:
        gate_up_scale = m.w13_weight_scale

    down_weight = m.w2_weight
    if hasattr(m, "w2_weight_scale_inv"):
        down_scale = m.w2_weight_scale_inv
    else:
        down_scale = m.w2_weight_scale

    top_k = m.top_k

    assert isinstance(gate_up_weight, torch.Tensor)
    assert isinstance(gate_up_scale, torch.Tensor)
    assert isinstance(down_weight, torch.Tensor)
    assert isinstance(down_scale, torch.Tensor)
    return gate_up_weight, gate_up_scale, down_weight, down_scale, top_k
|
|
|
|
|
|
def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
    """
    Return True if the input module/layer could be processed with DeepGEMM.

    A module qualifies when it is a LinearBase using a block-quantized
    Fp8LinearMethod, its quantization block size equals the alignment
    reported by get_mk_alignment_for_contiguous_layout(), and its weight is
    2-D with both dimensions divisible by the first alignment entry.
    """
    if not (
        isinstance(module, LinearBase)
        and isinstance(module.quant_method, Fp8LinearMethod)
        and module.quant_method.block_quant
    ):
        return False

    # Look the alignment up once, and only for modules that passed the cheap
    # type guard above (this predicate runs for every module of the model).
    alignment = get_mk_alignment_for_contiguous_layout()
    block_size = alignment[0]

    w, _, block_sizes = _extract_data_from_linear_base_module(module)
    return (
        block_sizes == alignment
        and w.ndim == 2
        and w.shape[0] % block_size == 0
        and w.shape[1] % block_size == 0
    )
|
|
|
|
|
|
def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
    """
    Return True if the given module is a FusedMoE layer whose grouped GEMM
    could run through DeepGEMM.

    Requires both VLLM_USE_DEEP_GEMM and VLLM_MOE_USE_DEEP_GEMM to be set, an
    FP8 (float8_e4m3fn) MoE quant config, and a quantization block shape equal
    to get_mk_alignment_for_contiguous_layout().
    """
    deep_gemm_enabled = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
    if not deep_gemm_enabled:
        return False

    if not isinstance(module, FusedMoE):
        return False

    quant_cfg = module.quant_method.get_fused_moe_quant_config(module)

    # Reject anything that is not block-quantized FP8 with the expected block
    # shape. The checks run in this order so later ones never see None.
    if quant_cfg is None:
        return False
    if quant_cfg.quant_dtype != torch.float8_e4m3fn:
        return False
    if quant_cfg.block_shape != get_mk_alignment_for_contiguous_layout():
        return False

    if isinstance(module.quant_method, FusedMoEModularMethod):
        # For modular kernels, check whether the experts implementation is
        # one of the DeepGEMM-capable classes.
        modular_kernel: FusedMoEModularKernel = module.quant_method.fused_experts
        return isinstance(
            modular_kernel.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts)
        )

    # No modular kernel to inspect: report that DeepGEMM may be used
    # (presumably through deep_gemm_moe_fp8 - confirm against the MoE layer).
    return True
|
|
|
|
|
|
# Weight shapes for which the fp8_gemm_nt warmup has already run. The warmup
# returns early for any weight whose shape is already in this set.
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
|
|
|
|
|
|
def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: int):
    """
    Warm up the DeepGEMM fp8_gemm_nt kernels for one weight tensor.

    Runs fp8_gemm_nt against uninitialized activation/output buffers for a
    set of token counts. Does nothing if a weight of the same shape was
    already warmed up (tracked in FP8_GEMM_NT_WARMUP_CACHE).

    Args:
        w: 2-D weight tensor; its size is unpacked as (n, k).
        ws: Scale tensor passed to fp8_gemm_nt together with `w`.
        max_tokens: Largest token count (M) to warm up for.

    Raises:
        AssertionError: if VLLM_DEEP_GEMM_WARMUP is neither "relax" nor "full".
    """
    if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
        return

    n, k = w.size()
    block_m = get_mk_alignment_for_contiguous_layout()[0]

    device = w.device
    a1q = torch.empty((max_tokens, k), device=device, dtype=torch.float8_e4m3fn)
    a1q_scales = torch.empty(
        (max_tokens, k // block_m), device=device, dtype=torch.float32
    )
    out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)

    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
        m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
    else:
        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
            "Expected "
            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
        )
        m_values = list(range(1, max_tokens + 1))
        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"

    # Use the progress bar as a context manager so it is always closed, even
    # if a kernel launch raises.
    with tqdm(total=len(m_values), desc=desc) as pbar:
        for num_tokens in m_values:
            fp8_gemm_nt(
                (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
            )
            pbar.update(1)

    FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
|
|
|
|
|
|
# Expert-weight shapes for which the m_grouped_fp8_gemm_nt_contiguous warmup
# has already run. Weights whose shape is in this set are not warmed up again.
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()
|
|
|
|
|
|
def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    num_topk: int,
    max_tokens: int,
):
    """
    Warm up the DeepGEMM m_grouped_fp8_gemm_nt_contiguous kernels for a pair
    of MoE expert weights.

    Each weight whose shape is not yet in
    GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE is run through the grouped
    GEMM for every multiple of block_m up to the maximum aligned M, using
    uninitialized activation/output buffers and random expert ids.

    Args:
        w1: First expert weight; its size is unpacked as (experts, n, k).
        w2: Second expert weight; must have the same number of experts as w1.
        w1_scale: Scale tensor passed to the kernel together with `w1`.
        w2_scale: Scale tensor passed to the kernel together with `w2`.
        num_topk: Top-k value forwarded to compute_aligned_M.
        max_tokens: Per-rank maximum number of tokens to warm up for.

    Raises:
        AssertionError: if w1 and w2 have a different number of experts.
    """
    if (
        w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
        and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
    ):
        return

    assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"

    block_m = get_mk_alignment_for_contiguous_layout()[0]
    num_experts = w1.size(0)
    device = w1.device

    # Assumes all ranks have the same max_num_batched_tokens
    max_tokens_across_dp = get_dp_group().world_size * max_tokens
    max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE)

    # This is the maximum GroupedGemm M size that we expect to run
    # the grouped_gemm with.
    MAX_M = compute_aligned_M(
        max_tokens, num_topk, num_experts, block_m, expert_tokens_meta=None
    )
    # Draw one random expert id per block of block_m rows, then repeat it so
    # that every row within a block maps to the same expert.
    MAX_BLOCKS = MAX_M // block_m
    expert_ids_block = torch.randint(
        low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32
    )
    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)

    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
        # Run the grouped GEMM for one weight over all block_m-aligned sizes.
        _, n, k = w.size()
        a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
        a1q_scales = torch.empty(
            (MAX_M, k // block_m), device=device, dtype=torch.float32
        )
        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)

        # Generate M values in block_m increments (already optimized for MoE)
        m_values = list(range(block_m, MAX_M + 1, block_m))

        # Use the progress bar as a context manager so it is always closed,
        # even if a kernel launch raises.
        with tqdm(
            total=len(m_values),
            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
            f"[{len(m_values)} values, block_m={block_m}]",
        ) as pbar:
            for num_tokens in m_values:
                m_grouped_fp8_gemm_nt_contiguous(
                    (a1q[:num_tokens], a1q_scales[:num_tokens]),
                    (w, w_scale),
                    out[:num_tokens],
                    expert_ids[:num_tokens],
                )
                pbar.update(1)

    for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
        if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
            _warmup(w, ws)
            GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())
|
|
|
|
|
|
def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
    """
    Run the fp8_gemm_nt warmup for every submodule of `model` that
    _fp8_linear_may_use_deep_gemm accepts.
    """
    # Collect all eligible layers first, then warm them up one by one.
    eligible_layers = list(filter(_fp8_linear_may_use_deep_gemm, model.modules()))

    for layer in eligible_layers:
        weight, weight_scale, _ = _extract_data_from_linear_base_module(layer)
        _deepgemm_fp8_gemm_nt_warmup(
            w=weight,
            ws=weight_scale,
            max_tokens=max_tokens,
        )
|
|
|
|
|
|
def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
    model: torch.nn.Module, max_tokens: int
):
    """
    Run the grouped FP8 GEMM warmup for every submodule of `model` that
    _fused_moe_grouped_gemm_may_use_deep_gemm accepts.
    """
    # Collect all eligible MoE layers first, then warm them up one by one.
    moe_layers = list(
        filter(_fused_moe_grouped_gemm_may_use_deep_gemm, model.modules())
    )

    for layer in moe_layers:
        extracted = _extract_data_from_fused_moe_module(layer)
        w13, w13_scale, w2, w2_scale, num_topk = extracted
        _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
            w13,
            w2,
            w13_scale,
            w2_scale,
            num_topk,
            max_tokens,
        )
|
|
|
|
|
|
def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
    """
    Run all DeepGEMM warmup passes on `model`: first the fp8_gemm_nt pass,
    then the grouped FP8 GEMM pass, both with the same `max_tokens`.
    """
    warmup_passes = (
        deepgemm_fp8_gemm_nt_warmup,
        deepgemm_grouped_fp8_gemm_nt_contiguous_warmup,
    )
    for run_pass in warmup_passes:
        run_pass(model, max_tokens)
|