Sync from v0.13
This commit is contained in:
400
vllm/utils/deep_gemm.py
Normal file
400
vllm/utils/deep_gemm.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Compatibility wrapper for DeepGEMM API changes.
|
||||
|
||||
Users of vLLM should always import **only** these wrappers.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import Any, NoReturn
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
class DeepGemmQuantScaleFMT(Enum):
|
||||
# Float32 scales in Float32 tensor
|
||||
FLOAT32 = 0
|
||||
# Compute float32 scales and ceil the scales to UE8M0.
|
||||
# Keep the scales in Float32 tensor.
|
||||
FLOAT32_CEIL_UE8M0 = 1
|
||||
# Compute float32 scales and ceil the scales to UE8M0.
|
||||
# Pack the scales into a int32 tensor where each int32
|
||||
# element contains 4 scale values.
|
||||
UE8M0 = 2
|
||||
|
||||
@staticmethod
|
||||
def from_oracle() -> "DeepGemmQuantScaleFMT":
|
||||
if not is_deep_gemm_e8m0_used():
|
||||
return DeepGemmQuantScaleFMT.FLOAT32
|
||||
return (
|
||||
DeepGemmQuantScaleFMT.UE8M0
|
||||
if current_platform.is_device_capability_family(100)
|
||||
else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
|
||||
)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def is_deep_gemm_supported() -> bool:
|
||||
"""Return `True` if DeepGEMM is supported on the current platform.
|
||||
Currently, only Hopper and Blackwell GPUs are supported.
|
||||
"""
|
||||
is_supported_arch = current_platform.is_cuda() and (
|
||||
current_platform.is_device_capability(90)
|
||||
or current_platform.is_device_capability_family(100)
|
||||
)
|
||||
return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
|
||||
|
||||
|
||||
@functools.cache
|
||||
def is_deep_gemm_e8m0_used() -> bool:
|
||||
"""Return `True` if vLLM is configured to use DeepGEMM "
|
||||
"E8M0 scale on a Hopper or Blackwell-class GPU.
|
||||
"""
|
||||
if not is_deep_gemm_supported():
|
||||
logger.debug_once(
|
||||
"DeepGEMM E8M0 disabled: DeepGEMM not supported on this system."
|
||||
)
|
||||
return False
|
||||
|
||||
_lazy_init()
|
||||
|
||||
if _fp8_gemm_nt_impl is None:
|
||||
logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
|
||||
return False
|
||||
|
||||
if envs.VLLM_USE_DEEP_GEMM_E8M0:
|
||||
logger.info_once("DeepGEMM E8M0 enabled on current platform.")
|
||||
return True
|
||||
|
||||
logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
|
||||
return False
|
||||
|
||||
|
||||
def _missing(*_: Any, **__: Any) -> NoReturn:
|
||||
"""Placeholder for unavailable DeepGEMM backend."""
|
||||
raise RuntimeError(
|
||||
"DeepGEMM backend is not available or outdated. Please install or "
|
||||
"update the `deep_gemm` to a newer version to enable FP8 kernels."
|
||||
)
|
||||
|
||||
|
||||
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
|
||||
_grouped_impl: Callable[..., Any] | None = None
|
||||
_grouped_masked_impl: Callable[..., Any] | None = None
|
||||
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
|
||||
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
|
||||
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
|
||||
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
|
||||
_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
|
||||
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
|
||||
|
||||
|
||||
def _lazy_init() -> None:
|
||||
"""Import deep_gemm and resolve symbols on first use."""
|
||||
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
|
||||
global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
|
||||
global _get_paged_mqa_logits_metadata_impl
|
||||
global _get_mn_major_tma_aligned_tensor_impl
|
||||
global _get_mk_alignment_for_contiguous_layout_impl
|
||||
global _transform_sf_into_required_layout_impl
|
||||
# fast path
|
||||
if (
|
||||
_fp8_gemm_nt_impl is not None
|
||||
or _grouped_impl is not None
|
||||
or _grouped_masked_impl is not None
|
||||
or _fp8_mqa_logits_impl is not None
|
||||
or _fp8_paged_mqa_logits_impl is not None
|
||||
or _get_paged_mqa_logits_metadata_impl is not None
|
||||
or _get_mk_alignment_for_contiguous_layout_impl is not None
|
||||
or _transform_sf_into_required_layout_impl is not None
|
||||
):
|
||||
return
|
||||
|
||||
if not has_deep_gemm():
|
||||
return
|
||||
|
||||
# Set up deep_gemm cache path
|
||||
DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR"
|
||||
if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
|
||||
os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT, "deep_gemm"
|
||||
)
|
||||
|
||||
_dg = importlib.import_module("deep_gemm")
|
||||
|
||||
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
|
||||
_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
|
||||
_grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
|
||||
_fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
|
||||
_fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
|
||||
_get_paged_mqa_logits_metadata_impl = getattr(
|
||||
_dg, "get_paged_mqa_logits_metadata", None
|
||||
)
|
||||
_get_mn_major_tma_aligned_tensor_impl = getattr(
|
||||
_dg, "get_mn_major_tma_aligned_tensor", None
|
||||
)
|
||||
_get_mk_alignment_for_contiguous_layout_impl = getattr(
|
||||
_dg, "get_mk_alignment_for_contiguous_layout", None
|
||||
)
|
||||
_transform_sf_into_required_layout_impl = getattr(
|
||||
_dg, "transform_sf_into_required_layout", None
|
||||
)
|
||||
|
||||
|
||||
def get_num_sms() -> int:
|
||||
_lazy_init()
|
||||
_dg = importlib.import_module("deep_gemm")
|
||||
return int(_dg.get_num_sms())
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_mk_alignment_for_contiguous_layout() -> list[int]:
|
||||
_lazy_init()
|
||||
if _get_mk_alignment_for_contiguous_layout_impl is None:
|
||||
return _missing()
|
||||
mk_align_size = _get_mk_alignment_for_contiguous_layout_impl()
|
||||
return [mk_align_size, mk_align_size]
|
||||
|
||||
|
||||
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
|
||||
_lazy_init()
|
||||
if _get_mn_major_tma_aligned_tensor_impl is None:
|
||||
return _missing()
|
||||
return _get_mn_major_tma_aligned_tensor_impl(x)
|
||||
|
||||
|
||||
def fp8_gemm_nt(*args, **kwargs):
|
||||
_lazy_init()
|
||||
if _fp8_gemm_nt_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
if "is_deep_gemm_e8m0_used" in kwargs:
|
||||
use_ue8m0 = kwargs["is_deep_gemm_e8m0_used"]
|
||||
del kwargs["is_deep_gemm_e8m0_used"]
|
||||
else:
|
||||
use_ue8m0 = is_deep_gemm_e8m0_used()
|
||||
return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
|
||||
|
||||
|
||||
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
|
||||
_lazy_init()
|
||||
if _grouped_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _grouped_impl(
|
||||
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
|
||||
)
|
||||
|
||||
|
||||
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
|
||||
_lazy_init()
|
||||
if _grouped_masked_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _grouped_masked_impl(
|
||||
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
|
||||
)
|
||||
|
||||
|
||||
def transform_sf_into_required_layout(*args, **kwargs):
|
||||
_lazy_init()
|
||||
if _transform_sf_into_required_layout_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _transform_sf_into_required_layout_impl(
|
||||
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
|
||||
)
|
||||
|
||||
|
||||
def fp8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv: tuple[torch.Tensor, torch.Tensor],
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits for a single sequence without KV paging.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape [M, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
|
||||
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
|
||||
[N, 1]) with dtype `torch.float32`.
|
||||
weights: weights of shape [M, H], dtype `torch.float32`.
|
||||
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [M, N], dtype `torch.float32`.
|
||||
"""
|
||||
_lazy_init()
|
||||
if _fp8_mqa_logits_impl is None:
|
||||
return _missing()
|
||||
return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
|
||||
|
||||
|
||||
def get_paged_mqa_logits_metadata(
|
||||
context_lens: torch.Tensor, block_size: int, num_sms: int
|
||||
) -> torch.Tensor:
|
||||
"""Build scheduling metadata for paged MQA logits.
|
||||
|
||||
Args:
|
||||
context_lens: Tensor of shape [B], dtype int32; effective context length
|
||||
per batch element.
|
||||
block_size: KV-cache block size in tokens (e.g., 64).
|
||||
num_sms: Number of SMs available. 132 for Hopper
|
||||
|
||||
Returns:
|
||||
Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
|
||||
schedule work across SMs.
|
||||
"""
|
||||
_lazy_init()
|
||||
if _get_paged_mqa_logits_metadata_impl is None:
|
||||
return _missing()
|
||||
return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms)
|
||||
|
||||
|
||||
def fp8_paged_mqa_logits(
|
||||
q_fp8: torch.Tensor,
|
||||
kv_cache_fp8: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
schedule_metadata: torch.Tensor,
|
||||
max_model_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits using paged KV-cache.
|
||||
|
||||
Args:
|
||||
q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
|
||||
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
|
||||
4 bytes per (block,pos) store the `float` dequant scale.
|
||||
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
|
||||
context_lens: Tensor of shape [B], dtype int32; effective context length
|
||||
for each batch element.
|
||||
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
|
||||
block indices to physical blocks in the paged cache.
|
||||
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
|
||||
used to distribute work across SMs.
|
||||
max_model_len: Maximum sequence length used to size the logits output.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [B * next_n, max_model_len], dtype
|
||||
`torch.float32`.
|
||||
"""
|
||||
_lazy_init()
|
||||
if _fp8_paged_mqa_logits_impl is None:
|
||||
return _missing()
|
||||
return _fp8_paged_mqa_logits_impl(
|
||||
q_fp8,
|
||||
kv_cache_fp8,
|
||||
weights,
|
||||
context_lens,
|
||||
block_tables,
|
||||
schedule_metadata,
|
||||
max_model_len,
|
||||
clean_logits=True,
|
||||
)
|
||||
|
||||
|
||||
def _ceil_to_ue8m0(x: torch.Tensor):
|
||||
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
|
||||
|
||||
|
||||
def _align(x: int, y: int) -> int:
|
||||
return cdiv(x, y) * y
|
||||
|
||||
|
||||
DEFAULT_BLOCK_SIZE = [128, 128]
|
||||
|
||||
|
||||
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
block_m, block_n = block_size
|
||||
x_padded = torch.zeros(
|
||||
(_align(m, block_m), _align(n, block_n)), dtype=x.dtype, device=x.device
|
||||
)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0
|
||||
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
|
||||
x_view.size(0), x_view.size(2)
|
||||
)
|
||||
|
||||
|
||||
def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
"""Return a global difference metric for unit tests.
|
||||
|
||||
DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
|
||||
error, causing `torch.testing.assert_close` to fail. Instead of checking
|
||||
every element, we compute a cosine-style similarity over the whole tensor
|
||||
and report `1 - sim`. Once kernel accuracy improves this helper can be
|
||||
removed.
|
||||
"""
|
||||
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype: torch.dtype,
|
||||
weight: torch.Tensor,
|
||||
supports_deep_gemm: bool | None = None,
|
||||
):
|
||||
if supports_deep_gemm is None:
|
||||
supports_deep_gemm = is_deep_gemm_supported()
|
||||
|
||||
# Verify DeepGEMM N/K dims requirements
|
||||
# NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
|
||||
# test inside kernels/quatization/test_block_fp8.py
|
||||
N_MULTIPLE = 64
|
||||
K_MULTIPLE = 128
|
||||
|
||||
return (
|
||||
supports_deep_gemm
|
||||
and output_dtype == torch.bfloat16
|
||||
and weight.shape[0] % N_MULTIPLE == 0
|
||||
and weight.shape[1] % K_MULTIPLE == 0
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"calc_diff",
|
||||
"DeepGemmQuantScaleFMT",
|
||||
"fp8_gemm_nt",
|
||||
"m_grouped_fp8_gemm_nt_contiguous",
|
||||
"fp8_m_grouped_gemm_nt_masked",
|
||||
"fp8_mqa_logits",
|
||||
"fp8_paged_mqa_logits",
|
||||
"get_paged_mqa_logits_metadata",
|
||||
"per_block_cast_to_fp8",
|
||||
"is_deep_gemm_e8m0_used",
|
||||
"is_deep_gemm_supported",
|
||||
"get_num_sms",
|
||||
"should_use_deepgemm_for_fp8_linear",
|
||||
"get_col_major_tma_aligned_tensor",
|
||||
"get_mk_alignment_for_contiguous_layout",
|
||||
]
|
||||
Reference in New Issue
Block a user