Sync from v0.13
This commit is contained in:
528
vllm/utils/flashinfer.py
Normal file
528
vllm/utils/flashinfer.py
Normal file
@@ -0,0 +1,528 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Compatibility wrapper for FlashInfer API changes.
|
||||
|
||||
Users of vLLM should always import **only** these wrappers.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Module-level logger for this compatibility shim.
logger = init_logger(__name__)

# This is the storage path for the cubins, it can be replaced
# with a local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501
# Read once at import time; the env var overrides the default NVIDIA mirror.
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_cubin() -> bool:
    """Return `True` if flashinfer-cubin package is available."""
    # Either the env override or an installed flashinfer_cubin package counts.
    available = (
        envs.VLLM_HAS_FLASHINFER_CUBIN
        or importlib.util.find_spec("flashinfer_cubin") is not None
    )
    if not available:
        logger.debug_once("flashinfer-cubin package was not found")
    return available
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer() -> bool:
    """Return `True` if flashinfer-python package is available."""
    # find_spec checks for the module without importing it, which avoids
    # any potential CUDA initialization side effects.
    if importlib.util.find_spec("flashinfer") is None:
        logger.debug_once("FlashInfer unavailable since package was not found")
        return False
    # Without pre-downloaded cubins, FlashInfer JIT-compiles kernels and
    # therefore needs nvcc on the PATH.
    needs_nvcc = not has_flashinfer_cubin()
    if needs_nvcc and shutil.which("nvcc") is None:
        logger.debug_once(
            "FlashInfer unavailable since nvcc was not found "
            "and not using pre-downloaded cubins"
        )
        return False
    return True
|
||||
|
||||
|
||||
def _missing(*_: Any, **__: Any) -> NoReturn:
|
||||
"""Placeholder for unavailable FlashInfer backend."""
|
||||
raise RuntimeError(
|
||||
"FlashInfer backend is not available. Please install the package "
|
||||
"to enable FlashInfer kernels: "
|
||||
"https://github.com/flashinfer-ai/flashinfer"
|
||||
)
|
||||
|
||||
|
||||
def _get_submodule(module_name: str) -> Any | None:
|
||||
"""Safely import a submodule and return it, or None if not available."""
|
||||
try:
|
||||
return importlib.import_module(module_name)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
return None
|
||||
|
||||
|
||||
# General lazy import wrapper
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
):
    """Create a lazy import wrapper for a specific function.

    The returned callable resolves `module_name.attr_name` on first use and
    caches the result; when FlashInfer (or the attribute) is missing, every
    call is routed to `fallback_fn` instead.
    """

    @functools.cache
    def _resolve():
        # Resolve the real implementation at most once.
        if not has_flashinfer():
            return None
        module = _get_submodule(module_name)
        if module is None:
            return None
        return getattr(module, attr_name, None)

    def wrapper(*args, **kwargs):
        target = _resolve()
        if target is None:
            return fallback_fn(*args, **kwargs)
        return target(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
# Create lazy wrappers for each function.
# Each name below resolves its FlashInfer implementation on first call; when
# FlashInfer is unavailable, calling the wrapper raises via `_missing`.
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "cutlass_fused_moe"
)
flashinfer_cutedsl_grouped_gemm_nt_masked = _lazy_import_wrapper(
    "flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"
)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_batched_quantize = _lazy_import_wrapper("flashinfer", "nvfp4_batched_quantize")
silu_and_mul_scaled_nvfp4_experts_quantize = _lazy_import_wrapper(
    "flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"
)
scaled_fp4_grouped_quantize = _lazy_import_wrapper(
    "flashinfer", "scaled_fp4_grouped_quantize"
)
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)

# Special case for autotune since it returns a context manager:
# fall back to a no-op context instead of raising when FlashInfer is absent.
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_comm() -> bool:
    """Return `True` if FlashInfer comm module is available."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.comm") is not None
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_all2all() -> bool:
    """Return `True` if FlashInfer mnnvl all2all is available."""
    if not has_flashinfer_comm():
        return False

    # Every one of these symbols must be importable for all2all to work.
    required = (
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    )
    return all(
        (mod := _get_submodule(name)) is not None and hasattr(mod, attr)
        for name, attr in required
    )
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_moe() -> bool:
    """Return `True` if FlashInfer MoE module is available."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.fused_moe") is not None
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_cutedsl() -> bool:
    """Return ``True`` if FlashInfer cutedsl module is available."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.cute_dsl") is not None
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return `True` if FlashInfer CUTLASS fused MoE is available."""
    if not has_flashinfer_moe():
        return False

    # Every one of these symbols must resolve for the CUTLASS MoE path.
    required = (
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    )
    return all(
        (mod := _get_submodule(name)) is not None and hasattr(mod, attr)
        for name, attr in required
    )
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
    """Return ``True`` if FlashInfer CuteDSL grouped_gemm_nt_masked is available.

    Checks that the grouped GEMM kernel and its required quantization helpers
    can all be resolved from the installed FlashInfer package.
    """
    if not has_flashinfer_cutedsl():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"),
        ("flashinfer", "scaled_fp4_grouped_quantize"),
        # NOTE(review): the lazy wrapper earlier in this module imports
        # "silu_and_mul_scaled_nvfp4_experts_quantize" while this check looks
        # for "silu_and_scaled_nvfp4_experts_quantize" — confirm which name
        # the installed FlashInfer release actually exports.
        ("flashinfer", "silu_and_scaled_nvfp4_experts_quantize"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True
|
||||
|
||||
|
||||
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return `True` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    which is required for downloading certain cubin kernels like TRTLLM FHMA.
    """
    # Pre-downloaded cubins make the remote check unnecessary.
    if has_flashinfer_cubin():
        return True

    try:
        # Use a short timeout to avoid blocking for too long.
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
    except Exception as e:
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False

    if response.status_code == 200:
        logger.debug_once("NVIDIA artifactory is accessible")
        return True
    logger.warning_once(
        "NVIDIA artifactory returned failed status code: %d",
        response.status_code,
    )
    return False
|
||||
|
||||
|
||||
@functools.cache
def supports_trtllm_attention() -> bool:
    """
    TRTLLM attention is supported if the platform is SM100,
    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
    """
    # Batch-invariant mode disables TRTLLM attention.
    if vllm_is_batch_invariant():
        return False

    # Requires an SM100-family device...
    if not current_platform.is_device_capability_family(100):
        return False
    # ...and the artifactory reachable so the cubins can be downloaded.
    return has_nvidia_artifactory()
|
||||
|
||||
|
||||
def force_use_trtllm_attention() -> bool | None:
    """
    This function should only be called during initialization stage when vllm config
    is set.
    Return `None` if --attention-config.use_trtllm_attention is not set,
    return `True` if TRTLLM attention is forced to be used,
    return `False` if TRTLLM attention is forced to be not used.
    """
    # Imported here to avoid a circular import at module load time.
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config()
    return config.attention_config.use_trtllm_attention
|
||||
|
||||
|
||||
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check if the current configuration supports TRTLLM attention."""
    # Explicitly forced off takes precedence over everything else.
    if force_use_trtllm_attention() is False:
        return False
    # Platform must support it and QO heads must be a multiple of KV heads.
    return supports_trtllm_attention() and num_qo_heads % num_kv_heads == 0
|
||||
|
||||
|
||||
def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    dcp_world_size: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    # None means auto-detection, True means force on, False means force off
    force_use_trtllm: bool | None = None,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return `True` if TRTLLM attention is used.

    The checks run in priority order: an explicit force-off wins, then hard
    incompatibilities (DCP, unsupported platform, head-count mismatch), then
    features that require TRTLLM (spec decode, FP8 queries, sinks), and
    finally auto-detection or an explicit force-on.

    Args:
        num_qo_heads: Number of query/output heads.
        num_kv_heads: Number of key/value heads; must divide `num_qo_heads`.
        num_tokens: Token count in the batch; decode auto-detection only
            enables TRTLLM for batches of at most 256 tokens.
        max_seq_len: Maximum sequence length. NOTE(review): not consulted by
            the decision logic below — confirm whether it is still needed.
        dcp_world_size: Decode-context-parallel world size; >1 disables TRTLLM.
        kv_cache_dtype: KV-cache dtype string; auto-detection requires "auto".
        q_dtype: Query dtype; an FP8 query forces TRTLLM on.
        is_prefill: Whether this decision is for the prefill phase.
        force_use_trtllm: None = auto-detect, True = force on, False = force off.
        has_sinks: Whether attention sinks are used (only TRTLLM supports them).
        has_spec: Whether speculative decoding is active (forces TRTLLM decode).
    """

    # CLI argument is set to 0 - respect it
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # Decode context parallel is not supported
    if dcp_world_size > 1:
        logger.warning_once(
            "Trtllm does not support returning LSE and as a result "
            "does not support DCP, reverting to FlashInfer"
        )
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        # Warn only when the user explicitly asked for TRTLLM.
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but --attention-config.use_trtllm_attention is set to 1"
            )
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but --attention-config.use_trtllm_attention is "
                "set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # CLI argument not set - use auto-detection
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
        else:
            # Decode auto-detection
            use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # CLI argument is set to 1 - respect it
    logger.info_once(
        "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
    )
    return True
|
||||
|
||||
|
||||
if has_flashinfer():
    # Register the FlashInfer GEMM kernels as torch custom ops so they compose
    # with torch.compile. The *_fake variants are shape/dtype-only
    # implementations used during tracing, when the real kernels cannot run.

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Import lazily so module import never triggers FlashInfer's own
        # (potentially CUDA-initializing) imports.
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Output has shape (A.shape[0], B.shape[1]) in the requested dtype.
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer import bmm_fp8 as bmm_fp8_

        # The `None` argument is FlashInfer's optional output tensor.
        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Batched matmul output shape: (A.shape[0], A.shape[1], B.shape[2]).
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )
|
||||
|
||||
|
||||
def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """Dispatch an FP4 scaled matmul to the `vllm::flashinfer_mm_fp4` op.

    Both inputs must be 2-D and row-contiguous with matching inner
    dimensions; `b` and its block scales are transposed before the call.
    """
    # Validate layouts before handing off to the kernel.
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]

    scale_a, scale_b = block_scale_a, block_scale_b
    if backend == "cutlass":
        # The CUTLASS path consumes the block scales as raw uint8 views.
        scale_a = scale_a.view(torch.uint8)
        scale_b = scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a, b.t(), scale_a, scale_b.t(), alpha, out_dtype, backend=backend
    )
|
||||
|
||||
|
||||
def flashinfer_scaled_fp8_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run an FP8 scaled matmul via the `vllm::bmm_fp8` op.

    Expects 2-D float8_e4m3fn operands on CUDA with scalar float32 scales;
    an optional bias is added to the result.
    """
    # Validate operand shapes, dtypes, and placement up front.
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    # bmm_fp8 takes batched operands: add a unit batch dim, then strip it.
    batched = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    )
    output = batched.view(a.shape[0], b.shape[1])

    return output if bias is None else output + bias
|
||||
|
||||
|
||||
# Explicit public API of this compatibility module. Previously several
# module-level names were defined but not exported
# (flashinfer_trtllm_fp8_per_tensor_scale_moe, nvfp4_batched_quantize,
# has_flashinfer_cubin, has_flashinfer_cutedsl, force_use_trtllm_attention);
# they are included here so `import *` and API tooling see them.
__all__ = [
    "has_flashinfer",
    "has_flashinfer_cubin",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_cutedsl_grouped_gemm_nt_masked",
    "flashinfer_fp4_quantize",
    "nvfp4_batched_quantize",
    "silu_and_mul_scaled_nvfp4_experts_quantize",
    "scaled_fp4_grouped_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutedsl",
    "has_flashinfer_cutlass_fused_moe",
    "has_flashinfer_cutedsl_grouped_gemm_nt_masked",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "force_use_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]
|
||||
Reference in New Issue
Block a user