init
 3566  vllm/utils/__init__.py  (new file; diff suppressed because it is too large)
 BIN   vllm/utils/__pycache__/__init__.cpython-312.pyc  (new file; binary file not shown)
 BIN   vllm/utils/__pycache__/jsontree.cpython-312.pyc  (new file; binary file not shown)
 319   vllm/utils/deep_gemm.py  (new file)
@@ -0,0 +1,319 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import functools
import importlib
import os
from typing import Any, Callable, NoReturn

import torch

import vllm.envs as envs
from vllm.logger import logger
from vllm.platforms import current_platform
from vllm.utils import cdiv, has_deep_gemm


@functools.cache
def is_deep_gemm_supported() -> bool:
    """Return ``True`` if DeepGEMM is supported on the current platform.
    Currently, only Hopper and Blackwell GPUs are supported.
    """
    is_supported_arch = current_platform.is_cuda() and (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability(100))
    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch


@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM E8M0 scales
    on a Hopper or Blackwell-class GPU.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.")
        return False

    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    if current_platform.is_device_capability(100) and \
            envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
        return True

    if current_platform.is_device_capability(90) and \
            envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER:
        logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.")
        return True

    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
    return False
def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable DeepGEMM backend."""
    raise RuntimeError(
        "DeepGEMM backend is not available. Please install the `deep_gemm` "
        "package to enable FP8 kernels.")


_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None


def _lazy_init() -> None:
    """Import deep_gemm and resolve symbols on first use."""
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
    global _get_paged_mqa_logits_metadata_impl
    global _get_mn_major_tma_aligned_tensor_impl

    # Fast path: if any symbol is already resolved, init has run.
    if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
            or _grouped_masked_impl is not None
            or _fp8_mqa_logits_impl is not None
            or _fp8_paged_mqa_logits_impl is not None
            or _get_paged_mqa_logits_metadata_impl is not None):
        return

    if not has_deep_gemm():
        return

    # Set up the deep_gemm JIT cache path if the user has not set one
    DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR'
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm")

    _dg = importlib.import_module("deep_gemm")

    _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
    _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
    _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
    _get_paged_mqa_logits_metadata_impl = getattr(
        _dg, "get_paged_mqa_logits_metadata", None)
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        _dg, "get_mn_major_tma_aligned_tensor", None)


def get_num_sms() -> int:
    _lazy_init()
    _dg = importlib.import_module("deep_gemm")
    return int(_dg.get_num_sms())


def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor."""
    _lazy_init()
    if _get_mn_major_tma_aligned_tensor_impl is None:
        return _missing()
    return _get_mn_major_tma_aligned_tensor_impl(x)


def fp8_gemm_nt(*args, **kwargs):
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    return _fp8_gemm_nt_impl(*args,
                             disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
                             **kwargs)


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    _lazy_init()
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_impl(*args,
                         disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
                         **kwargs)


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    _lazy_init()
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_masked_impl(
        *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query
            position, shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query
            position, shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    _lazy_init()
    if _fp8_mqa_logits_impl is None:
        return _missing()
    return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
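# A minimal usage sketch for `fp8_mqa_logits` (assumes a Hopper/Blackwell
# GPU with DeepGEMM installed; sizes are illustrative only). Kept as a
# function so it stays inert on import.
def _example_fp8_mqa_logits() -> torch.Tensor:
    m, h, d, n = 4, 8, 64, 16
    q = torch.randn(m, h, d, device="cuda").to(torch.float8_e4m3fn)
    k_fp8 = torch.randn(n, d, device="cuda").to(torch.float8_e4m3fn)
    k_scales = torch.ones(n, device="cuda", dtype=torch.float32)
    weights = torch.ones(m, h, device="cuda", dtype=torch.float32)
    # Query position i attends to keys in [cu_seqlen_ks[i], cu_seqlen_ke[i]).
    ks = torch.zeros(m, device="cuda", dtype=torch.int32)
    ke = torch.full((m, ), n, device="cuda", dtype=torch.int32)
    return fp8_mqa_logits(q, (k_fp8, k_scales), weights, ks, ke)  # [M, N]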
def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int,
                                  num_sms: int) -> torch.Tensor:
    """Build scheduling metadata for paged MQA logits.

    Args:
        context_lens: Tensor of shape [B], dtype int32; effective context
            length per batch element.
        block_size: KV-cache block size in tokens (e.g., 64).
        num_sms: Number of SMs available; 132 for Hopper.

    Returns:
        Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
        schedule work across SMs.
    """
    _lazy_init()
    if _get_paged_mqa_logits_metadata_impl is None:
        return _missing()
    return _get_paged_mqa_logits_metadata_impl(context_lens, block_size,
                                               num_sms)
def fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using a paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block, pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context
            length for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps
            logical block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits
            output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    _lazy_init()
    if _fp8_paged_mqa_logits_impl is None:
        return _missing()
    return _fp8_paged_mqa_logits_impl(q_fp8,
                                      kv_cache_fp8,
                                      weights,
                                      context_lens,
                                      block_tables,
                                      schedule_metadata,
                                      max_model_len,
                                      clean_logits=True)
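# A sketch of the packed FP8+scale layout documented above: for one cache
# block, the D fp8 key bytes are followed by the 4 bytes of the float32
# dequant scale. This helper is illustrative only, not used by vLLM.
def _example_pack_fp8_kv_block(k_fp8: torch.Tensor,
                               scales: torch.Tensor) -> torch.Tensor:
    block_size, d = k_fp8.shape  # k_fp8: [block_size, D], float8_e4m3fn
    packed = torch.empty(block_size,
                         d + 4,
                         dtype=torch.uint8,
                         device=k_fp8.device)
    packed[:, :d] = k_fp8.view(torch.uint8)
    # scales: [block_size] float32 -> 4 bytes per position.
    packed[:, d:] = scales.contiguous().view(torch.uint8).view(block_size, 4)
    return packed.view(1, block_size, 1, d + 4)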
def _ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))


def _align(x: int, y: int) -> int:
    return cdiv(x, y) * y


DEFAULT_BLOCK_SIZE = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size: list[int] = DEFAULT_BLOCK_SIZE,
        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    block_m, block_n = block_size
    x_padded = torch.zeros((_align(m, block_m), _align(n, block_n)),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2))
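# A worked sketch of `per_block_cast_to_fp8` (illustrative sizes; assumes a
# CUDA device). A 256x384 input is split into 128x128 blocks, giving one
# scale per block; with use_ue8m0=True each scale is rounded up to a power
# of two, e.g. _ceil_to_ue8m0 maps 0.3 -> 0.5 and 1.1 -> 2.0.
def _example_per_block_cast() -> None:
    x = torch.randn(256, 384, device="cuda", dtype=torch.bfloat16)
    x_fp8, scales = per_block_cast_to_fp8(x, use_ue8m0=True)
    assert x_fp8.shape == (256, 384)  # fp8 payload, original shape
    assert scales.shape == (2, 3)  # (256/128, 384/128) blocks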
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable
    per-element error, causing ``torch.testing.assert_close`` to fail.
    Instead of checking every element, we compute a cosine-style similarity
    over the whole tensor and report ``1 - sim``. Once kernel accuracy
    improves, this helper can be removed.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
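# Interpretation sketch: identical tensors give 0 and anti-correlated
# tensors give 2, so unit tests check for values near 0 (the exact
# tolerance is up to the caller, not fixed here):
#
#   x = torch.randn(128, 128)
#   assert calc_diff(x, x) == 0   # sim = 1
#   assert calc_diff(x, -x) == 2  # sim = -1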
def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
                                       weight: torch.Tensor):
    return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)


__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "fp8_mqa_logits",
    "fp8_paged_mqa_logits",
    "get_paged_mqa_logits_metadata",
    "per_block_cast_to_fp8",
    "is_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "get_num_sms",
    "should_use_deepgemm_for_fp8_linear",
    "get_col_major_tma_aligned_tensor",
]
 443   vllm/utils/flashinfer.py  (new file)
@@ -0,0 +1,443 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import contextlib
import functools
import importlib
import importlib.util
import os
from typing import Any, Callable, NoReturn, Optional

import requests
import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# This is the storage path for the cubins; it can be replaced
# with a local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)


@functools.cache
def has_flashinfer() -> bool:
    """Return ``True`` if FlashInfer is available."""
    # Use find_spec to check if the module exists without importing it.
    # This avoids potential CUDA initialization side effects.
    return importlib.util.find_spec("flashinfer") is not None


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable FlashInfer backend."""
    raise RuntimeError(
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
        "https://github.com/flashinfer-ai/flashinfer")


def _get_submodule(module_name: str) -> Any | None:
    """Safely import a submodule and return it, or None if not available."""
    try:
        return importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        return None


# General lazy import wrapper
def _lazy_import_wrapper(module_name: str,
                         attr_name: str,
                         fallback_fn: Callable[..., Any] = _missing):
    """Create a lazy import wrapper for a specific function."""

    @functools.cache
    def _get_impl():
        if not has_flashinfer():
            return None
        mod = _get_submodule(module_name)
        return getattr(mod, attr_name, None) if mod else None

    def wrapper(*args, **kwargs):
        impl = _get_impl()
        if impl is None:
            return fallback_fn(*args, **kwargs)
        return impl(*args, **kwargs)

    return wrapper


# Create lazy wrappers for each function
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe")
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
                                                    "cutlass_fused_moe")
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave")
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe")

# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())
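# Usage sketch for the lazy wrappers above: each resolves its FlashInfer
# symbol on first call and falls back to `_missing` (a RuntimeError) when
# the package is absent. Adding a new wrapper is one line; the attribute
# name below is hypothetical:
#
#   my_kernel = _lazy_import_wrapper("flashinfer", "my_kernel")
#   out = my_kernel(x)  # resolves once, then calls the real kernel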
@functools.cache
def has_flashinfer_comm() -> bool:
    """Return ``True`` if the FlashInfer comm module is available."""
    return has_flashinfer() and importlib.util.find_spec(
        "flashinfer.comm") is not None


@functools.cache
def has_flashinfer_all2all() -> bool:
    """Return ``True`` if FlashInfer mnnvl all2all is available."""
    if not has_flashinfer_comm():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_flashinfer_moe() -> bool:
    """Return ``True`` if the FlashInfer MoE module is available."""
    return has_flashinfer() and importlib.util.find_spec(
        "flashinfer.fused_moe") is not None


@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return ``True`` if FlashInfer CUTLASS fused MoE is available."""
    if not has_flashinfer_moe():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return ``True`` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory,
    which is required for downloading certain cubin kernels like TRTLLM FMHA.
    """
    # FLASHINFER_CUBIN_DIR defines the path to pre-downloaded cubins; when
    # it is set, assume the cubins are available without probing the network.
    if envs.VLLM_HAS_FLASHINFER_CUBIN:
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
        accessible = response.status_code == 200
        if accessible:
            logger.debug_once("NVIDIA artifactory is accessible")
        else:
            logger.warning_once(
                "NVIDIA artifactory returned failure status code: %d",
                response.status_code)
        return accessible
    except Exception as e:
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False
@functools.cache
def supports_trtllm_attention() -> bool:
    """Return ``True`` if TRTLLM attention is supported, i.e. the platform
    is SM100 and the NVIDIA artifactory is accessible.
    """
    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    return current_platform.is_device_capability(
        100) and has_nvidia_artifactory()


@functools.cache
def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION."""
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
    return env_value


def force_use_trtllm_attention() -> Optional[bool]:
    """Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
    ``True`` if TRTLLM attention is forced on,
    and ``False`` if it is forced off.
    """
    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)


def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check if the current configuration supports TRTLLM attention."""
    has_trtllm = supports_trtllm_attention()
    return has_trtllm and (num_qo_heads % num_kv_heads == 0)


def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return ``True`` if TRTLLM attention should be used."""
    force_use_trtllm = force_use_trtllm_attention()

    # Environment variable is set to 0 - respect it
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but VLLM_USE_TRTLLM_ATTENTION is set to 1")
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once(
            "Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        if has_sinks:
            raise RuntimeError(
                "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
                "Use kv_cache_dtype=auto for now.")
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once(
            "Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # Environment variable not set - use auto-detection
        use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072
                      and kv_cache_dtype == "auto")
        if use_trtllm:
            logger.warning_once("Using TRTLLM attention (auto-detected).")
        return use_trtllm

    # Environment variable is set to 1 - respect it
    logger.info_once(
        "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
    return True
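# Decision sketch for `use_trtllm_attention` (values illustrative): a
# typical bf16 decode batch hits the auto-detection path, so it returns
# True only on SM100 with artifactory access and an "auto" KV-cache dtype:
#
#   use_trtllm_attention(num_qo_heads=32, num_kv_heads=8, num_tokens=64,
#                        max_seq_len=8192, kv_cache_dtype="auto",
#                        q_dtype=torch.bfloat16, is_prefill=False)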
if has_flashinfer():

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_
        return flashinfer_mm_fp4_(A,
                                  B,
                                  A_scale,
                                  B_scale,
                                  g_scale,
                                  dtype,
                                  block_size=16,
                                  backend=backend)

    # The fake impl only reports output shape/dtype so torch.compile can
    # trace through the op without running the real kernel.
    @torch.library.register_fake("vllm::flashinfer_mm_fp4")
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        return torch.empty(A.shape[0],
                           B.shape[1],
                           dtype=dtype,
                           device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer import bmm_fp8 as bmm_fp8_
        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake("vllm::bmm_fp8")
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        return torch.empty(A.shape[0],
                           A.shape[1],
                           B.shape[2],
                           dtype=dtype,
                           device=A.device)


def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
                             block_scale_a: torch.Tensor,
                             block_scale_b: torch.Tensor, alpha: torch.Tensor,
                             out_dtype: torch.dtype,
                             backend: str) -> torch.Tensor:
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]
    assert block_scale_a.shape[1] == a.shape[1] // 8
    assert block_scale_b.shape[1] == b.shape[1] // 8

    if backend == "cutlass":
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        backend=backend,
    )
def flashinfer_scaled_fp8_mm(
        a: torch.Tensor,
        b: torch.Tensor,
        scale_a: torch.Tensor,
        scale_b: torch.Tensor,
        out_dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    output = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    if bias is not None:
        output = output + bias
    return output
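# Usage sketch for `flashinfer_scaled_fp8_mm` (illustrative sizes; assumes
# a CUDA device with FlashInfer installed). As at vLLM call sites, b is a
# transposed weight, i.e. a column-major [K, N] view:
#
#   a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
#   b = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn).t()
#   one = torch.tensor(1.0, device="cuda", dtype=torch.float32)
#   out = flashinfer_scaled_fp8_mm(a, b, one, one, torch.bfloat16)  # [16, 64]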
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Cache result which only depends on the environment."""
    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION


__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_disable_q_quantization",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]
 178   vllm/utils/jsontree.py  (new file)
@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helper functions to work with nested JSON structures."""

from collections.abc import Iterable
from functools import reduce
from typing import TYPE_CHECKING, Callable, TypeVar, Union, cast, overload

if TYPE_CHECKING:
    import torch

    from vllm.multimodal.inputs import BatchedTensorInputs

_T = TypeVar("_T")
_U = TypeVar("_U")

JSONTree = Union[
    dict[str, "JSONTree[_T]"],
    list["JSONTree[_T]"],
    tuple["JSONTree[_T]", ...],
    _T,
]
"""A nested JSON structure where the leaves need not be JSON-serializable."""

_JSONTree = Union[
    dict[str, "JSONTree[_T]"],
    list["JSONTree[_T]"],
    tuple["JSONTree[_T]", ...],
    dict[str, _T],
    list[_T],
    tuple[_T, ...],
    _T,
]
"""
Same as `JSONTree` but with additional `Union` members to satisfy overloads.
"""


def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
    """Iterate through each leaf in a nested JSON structure."""
    if isinstance(value, dict):
        for v in value.values():
            yield from json_iter_leaves(v)
    elif isinstance(value, (list, tuple)):
        for v in value:
            yield from json_iter_leaves(v)
    else:
        yield value


@overload
def json_map_leaves(
    func: Callable[["torch.Tensor"], "torch.Tensor"],
    value: "BatchedTensorInputs",
) -> "BatchedTensorInputs":
    ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: Union[_T, dict[str, _T]],
) -> Union[_U, dict[str, _U]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: Union[_T, list[_T]],
) -> Union[_U, list[_U]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: Union[_T, tuple[_T, ...]],
) -> Union[_U, tuple[_U, ...]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: JSONTree[_T],
) -> JSONTree[_U]:
    ...


def json_map_leaves(
    func: Callable[[_T], _U],
    value: Union["BatchedTensorInputs", _JSONTree[_T]],
) -> Union["BatchedTensorInputs", _JSONTree[_U]]:
    """Apply a function to each leaf in a nested JSON structure."""
    if isinstance(value, dict):
        return {
            k: json_map_leaves(func, v)  # type: ignore[arg-type]
            for k, v in value.items()
        }
    elif isinstance(value, list):
        return [json_map_leaves(func, v) for v in value]
    elif isinstance(value, tuple):
        return tuple(json_map_leaves(func, v) for v in value)
    else:
        return func(value)


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: Union[_T, dict[str, _T]],
    /,
) -> _T:
    ...


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: Union[_T, list[_T]],
    /,
) -> _T:
    ...


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: Union[_T, tuple[_T, ...]],
    /,
) -> _T:
    ...


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: JSONTree[_T],
    /,
) -> _T:
    ...


@overload
def json_reduce_leaves(
    func: Callable[[_U, _T], _U],
    value: JSONTree[_T],
    initial: _U,
    /,
) -> _U:
    ...


def json_reduce_leaves(
    func: Callable[..., Union[_T, _U]],
    value: _JSONTree[_T],
    initial: _U = cast(_U, ...),  # noqa: B008
    /,
) -> Union[_T, _U]:
    """
    Apply a function of two arguments cumulatively to each leaf in a
    nested JSON structure, from left to right, so as to reduce the
    sequence to a single value.
    """
    if initial is ...:
        return reduce(func, json_iter_leaves(value))  # type: ignore[arg-type]

    return reduce(
        func,  # type: ignore[arg-type]
        json_iter_leaves(value),
        initial,
    )


def json_count_leaves(value: JSONTree[_T]) -> int:
    """Count the number of leaves in a nested JSON structure."""
    return sum(1 for _ in json_iter_leaves(value))
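# A minimal usage sketch of the helpers above on a plain nested structure:
#
#   tree = {"a": [1, 2], "b": {"c": (3, )}}
#   json_count_leaves(tree)                              # -> 3
#   json_map_leaves(lambda x: x * 10, tree)              # -> {"a": [10, 20],
#                                                        #     "b": {"c": (30, )}}
#   json_reduce_leaves(lambda acc, x: acc + x, tree, 0)  # -> 6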
 235   vllm/utils/tensor_schema.py  (new file)
@@ -0,0 +1,235 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import (Annotated, Any, Optional, Union, get_args, get_origin,
                    get_type_hints)

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class TensorShape:
    """Symbolic tensor shape: each dim is an int or a named symbol, and
    dims listed in ``dynamic_dims`` are exempt from consistency checks."""

    def __init__(
        self,
        *dims: Union[int, str],
        dynamic_dims: Optional[set[str]] = None,
    ) -> None:
        super().__init__()

        self.dims = dims
        self.dynamic_dims = dynamic_dims if dynamic_dims else set()

    def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]:
        resolved = list[Union[int, str]]()
        for dim in self.dims:
            if isinstance(dim, str) and dim in bindings:
                resolved.append(bindings[dim])
            else:
                resolved.append(dim)
        return tuple(resolved)

    def __str__(self) -> str:
        """Return a string representation of the tensor shape."""
        dim_strs = []
        for dim in self.dims:
            if isinstance(dim, str):
                if dim in self.dynamic_dims:
                    # Mark dynamic dimensions with *
                    dim_strs.append(f"{dim}*")
                else:
                    dim_strs.append(dim)
            else:
                dim_strs.append(str(dim))
        return f"({', '.join(dim_strs)})"


class TensorSchema:
    """Base class whose subclasses declare tensor fields via ``Annotated``
    type hints carrying a ``TensorShape``; shapes are validated on init."""

    def __init__(
        self,
        *,
        validate: bool = True,
        resolve_bindings: Optional[dict[str, int]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()

        self._resolve_bindings = resolve_bindings if resolve_bindings else {}

        for key, value in kwargs.items():
            setattr(self, key, value)

        if validate:
            self.validate()

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)

    def get(self, key: str, default: Any = None) -> Any:
        return getattr(self, key, default)

    def _match_shape_with_dynamic(
        self,
        actual: tuple[int, ...],
        reference: tuple[int, ...],
        expected_shape: tuple[Union[int, str], ...],
        dynamic_dims: set[str],
    ) -> bool:
        if len(actual) != len(reference) or len(actual) > len(expected_shape):
            return False

        for i, (a, r) in enumerate(zip(actual, reference)):
            # When validating list inputs, we match shape suffixes only
            # (e.g. "p", 3, "h", "w"), assuming the list length corresponds
            # to the leading symbolic dim (e.g. "bn"). This allows comparing
            # only the trailing dimensions of each element in the list.
            dim = expected_shape[-len(actual) + i]
            # Skip this dimension if it's marked dynamic
            if dim in dynamic_dims:
                continue
            if a != r:
                return False
        return True

    def _validate_nested_tensors(
        self,
        value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
        field_name: str,
        expected_shape: tuple[Union[int, str], ...],
        dynamic_dims: set[str],
    ) -> tuple[int, ...]:
        """Validate a list/tuple of tensors and return the actual shape."""
        # Ensure all tensors in the list have the same
        # shape, besides dynamic dimensions
        first = value[0]
        for i, v in enumerate(value):
            if not isinstance(v, torch.Tensor):
                raise ValueError(f"{field_name}[{i}] is not a torch.Tensor")
            if not self._match_shape_with_dynamic(
                    v.shape,
                    first.shape,
                    expected_shape,
                    dynamic_dims,
            ):
                raise ValueError(f"{field_name} contains inconsistent "
                                 f"shapes: {first.shape} vs {v.shape} "
                                 f"at index {i}")

        # Treat the list as a stacked tensor:
        # shape = (len(list), *tensor.shape)
        return (len(value), ) + first.shape

    def _validate_tensor_shape_expected(
        self,
        actual_shape: tuple[int, ...],
        expected_shape: tuple[Union[int, str], ...],
        field_name: str,
        shape_env: dict[str, int],
        dynamic_dims: set[str],
    ) -> None:
        """Validate that the actual tensor shape matches the expected shape."""

        if len(actual_shape) != len(expected_shape):
            raise ValueError(f"{field_name} has rank {len(actual_shape)} "
                             f"but expected {len(expected_shape)}")

        for i, dim in enumerate(expected_shape):
            if dim in dynamic_dims:
                continue
            elif isinstance(dim, int):
                if actual_shape[i] != dim:
                    raise ValueError(f"{field_name} dim[{i}] expected "
                                     f"{dim}, got {actual_shape[i]}")
            elif isinstance(dim, str):
                if dim in shape_env:
                    if actual_shape[i] != shape_env[dim]:
                        raise ValueError(f"{field_name} dim[{i}] expected "
                                         f"'{dim}'={shape_env[dim]}, got "
                                         f"{actual_shape[i]}")
                else:
                    shape_env[dim] = actual_shape[i]
            else:
                raise TypeError(f"{field_name} dim[{i}] has unsupported "
                                f"type: {type(dim)}")

    def validate(self) -> None:
        type_hints = get_type_hints(self.__class__, include_extras=True)
        shape_env = dict[str, int]()

        for field_name, field_type in type_hints.items():
            # Check if field is missing
            if (not hasattr(self, field_name)
                    or getattr(self, field_name) is None):
                # Check if field is marked as optional
                actual_type = field_type
                if get_origin(field_type) is Annotated:
                    args = get_args(field_type)
                    actual_type = args[0]

                # Check arg was provided as Union
                if get_origin(actual_type) is Union:
                    args = get_args(actual_type)
                    # Skip validation when Union contains None
                    if type(None) in args:
                        continue
                # Otherwise field is required, raise error
                raise ValueError(f"Required field '{field_name}' is missing")

            # Field exists, proceed with validation
            value = getattr(self, field_name)
            if get_origin(field_type) is not None:
                args = get_args(field_type)

                for arg in args:
                    if isinstance(arg, TensorShape):
                        expected_shape = arg.resolve(**self._resolve_bindings)
                        if isinstance(value, (list, tuple)):
                            # list/tuple of Tensors → shape = (len(value), ...)
                            if value and isinstance(value[0], torch.Tensor):
                                actual_shape = self._validate_nested_tensors(
                                    value, field_name, expected_shape,
                                    arg.dynamic_dims)
                            elif value:
                                # list/tuple of scalars → shape = (len(value),)
                                actual_shape = (len(value), )
                            else:
                                raise ValueError(
                                    f"{field_name} is an empty list")

                        # Tensor → shape = tensor.shape
                        elif isinstance(value, torch.Tensor):
                            actual_shape = value.shape

                        # Otherwise, it's an unsupported type
                        else:
                            type_names = []
                            # Note: iterate a fresh name to avoid shadowing
                            # the TensorShape `arg` from the outer loop.
                            for t in args:
                                if hasattr(t, "__name__"):
                                    type_names.append(str(t.__name__))
                                else:
                                    type_names.append(str(t))

                            expected_types = ", ".join(type_names)
                            raise ValueError(
                                f"{field_name} is not one of the expected "
                                f"types: {expected_types}")

                        self._validate_tensor_shape_expected(
                            actual_shape, expected_shape, field_name,
                            shape_env, arg.dynamic_dims)

    def print_shapes(self) -> None:
        """Print TensorShape annotations for debugging."""
        logger.debug("Shapes in %s:", self.__class__.__name__)
        type_hints = get_type_hints(self.__class__, include_extras=True)

        for field_name, field_type in type_hints.items():
            if get_origin(field_type) is not None:
                args = get_args(field_type)
                for arg in args:
                    if isinstance(arg, TensorShape):
                        logger.debug("  %s: %s", field_name, str(arg))
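# A minimal declaration/validation sketch (field names and sizes are
# illustrative, not from vLLM):
#
#   class ExampleImageInputs(TensorSchema):
#       pixel_values: Annotated[torch.Tensor,
#                               TensorShape("bn", 3, "h", "w",
#                                           dynamic_dims={"h", "w"})]
#
#   # Validates on construction: "bn" binds to 2, the literal 3 must
#   # match, and "h"/"w" are dynamic, so any height/width is accepted.
#   ExampleImageInputs(pixel_values=torch.zeros(2, 3, 224, 336))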