Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -418,6 +418,125 @@ def should_use_deepgemm_for_fp8_linear(
|
||||
)
|
||||
|
||||
|
||||
def fp8_mqa_logits_torch(
|
||||
q: torch.Tensor,
|
||||
kv: tuple[torch.Tensor, torch.Tensor],
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits for a single sequence without KV paging (CUDA fallback).
|
||||
|
||||
This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape [M, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
|
||||
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
|
||||
[N, 1]) with dtype `torch.float32`.
|
||||
weights: weights of shape [M, H], dtype `torch.float32`.
|
||||
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [M, N], dtype `torch.float32`.
|
||||
"""
|
||||
kv_fp8, scale = kv
|
||||
seq_len_kv = kv_fp8.shape[0]
|
||||
k = kv_fp8.to(torch.bfloat16)
|
||||
q = q.to(torch.bfloat16)
|
||||
|
||||
mask_lo = (
|
||||
torch.arange(0, seq_len_kv, device=q.device)[None, :] >= cu_seqlen_ks[:, None]
|
||||
)
|
||||
mask_hi = (
|
||||
torch.arange(0, seq_len_kv, device=q.device)[None, :] < cu_seqlen_ke[:, None]
|
||||
)
|
||||
mask = mask_lo & mask_hi
|
||||
|
||||
score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
|
||||
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def fp8_paged_mqa_logits_torch(
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
max_model_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits using paged KV-cache (CUDA fallback).
|
||||
|
||||
This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
|
||||
Handles head_dim = 132 (128 + 4 for RoPE).
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape [B, next_n, H, D].
|
||||
kv_cache: Paged KV-cache in packed FP8+scale layout with shape
|
||||
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
|
||||
4 bytes per (block,pos) store the `float` dequant scale.
|
||||
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
|
||||
context_lens: Tensor of shape [B], dtype int32; effective context length
|
||||
for each batch element.
|
||||
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
|
||||
block indices to physical blocks in the paged cache.
|
||||
max_model_len: Maximum sequence length used to size the logits output.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [B * next_n, max_model_len], dtype
|
||||
`torch.float32`.
|
||||
"""
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
batch_size, next_n, heads, dim = q.size()
|
||||
kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
|
||||
scale = scale.contiguous().view(torch.float)
|
||||
q = q.float()
|
||||
kv_cache = kv_cache.view(fp8_dtype).float() * scale
|
||||
num_blocks, block_size, _, dim = kv_cache.size()
|
||||
logits = torch.full(
|
||||
[batch_size * next_n, max_model_len],
|
||||
float("-inf"),
|
||||
device=q.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
for i in range(batch_size):
|
||||
context_len = context_lens[i].item()
|
||||
q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
|
||||
weight_slice = (
|
||||
weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
|
||||
)
|
||||
for block_idx in range(cdiv(context_len, block_size)):
|
||||
block_id = block_tables[i][block_idx]
|
||||
qx, kx = q[i], kv_cache[block_id]
|
||||
k_offsets = torch.arange(
|
||||
block_idx * block_size, (block_idx + 1) * block_size, device=q.device
|
||||
)
|
||||
mask = (k_offsets[None, :] < context_len) & (
|
||||
k_offsets[None, :] <= q_offsets[:, None]
|
||||
)
|
||||
s = torch.where(
|
||||
mask[None, :, :],
|
||||
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
|
||||
logits.dtype
|
||||
),
|
||||
float("-inf"),
|
||||
)
|
||||
s = torch.relu(s) * weight_slice[..., None]
|
||||
s = s.sum(dim=0)
|
||||
logits[
|
||||
i * next_n : (i + 1) * next_n,
|
||||
block_idx * block_size : (block_idx + 1) * block_size,
|
||||
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
__all__ = [
|
||||
"calc_diff",
|
||||
"DeepGemmQuantScaleFMT",
|
||||
@@ -425,7 +544,9 @@ __all__ = [
|
||||
"m_grouped_fp8_gemm_nt_contiguous",
|
||||
"fp8_m_grouped_gemm_nt_masked",
|
||||
"fp8_mqa_logits",
|
||||
"fp8_mqa_logits_torch",
|
||||
"fp8_paged_mqa_logits",
|
||||
"fp8_paged_mqa_logits_torch",
|
||||
"get_paged_mqa_logits_metadata",
|
||||
"per_block_cast_to_fp8",
|
||||
"is_deep_gemm_e8m0_used",
|
||||
|
||||
@@ -19,9 +19,6 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
# from vllm.model_executor.layers.batch_invariant import (
|
||||
# vllm_is_batch_invariant,
|
||||
# )
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -140,6 +137,7 @@ autotune = _lazy_import_wrapper(
|
||||
"autotune",
|
||||
fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
|
||||
)
|
||||
_is_fi_autotuning: bool = False
|
||||
|
||||
|
||||
@functools.cache
|
||||
@@ -279,6 +277,9 @@ def supports_trtllm_attention() -> bool:
|
||||
TRTLLM attention is supported if the platform is SM100,
|
||||
NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
|
||||
"""
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
# Batch-invariant mode disables TRTLLM attention
|
||||
if vllm_is_batch_invariant():
|
||||
return False
|
||||
@@ -734,7 +735,7 @@ def should_use_flashinfer_for_blockscale_fp8_gemm(
|
||||
|
||||
# Verify DeepGEMM N/K dims requirements
|
||||
# NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
|
||||
# test inside kernels/quatization/test_block_fp8.py
|
||||
# test inside kernels/quantization/test_block_fp8.py
|
||||
N_MULTIPLE = 64
|
||||
K_MULTIPLE = 128
|
||||
|
||||
|
||||
@@ -402,11 +402,6 @@ def _has_module(module_name: str) -> bool:
|
||||
return importlib.util.find_spec(module_name) is not None
|
||||
|
||||
|
||||
def has_pplx() -> bool:
|
||||
"""Whether the optional `pplx_kernels` package is available."""
|
||||
return _has_module("pplx_kernels")
|
||||
|
||||
|
||||
def has_deep_ep() -> bool:
|
||||
"""Whether the optional `deep_ep` package is available."""
|
||||
return _has_module("deep_ep")
|
||||
|
||||
@@ -30,3 +30,8 @@ def round_up(x: int, y: int) -> int:
|
||||
def round_down(x: int, y: int) -> int:
|
||||
"""Round down x to the nearest multiple of y."""
|
||||
return (x // y) * y
|
||||
|
||||
|
||||
def largest_power_of_2_divisor(n: int) -> int:
|
||||
"""Return the largest power-of-2 that divides *n* (isolate lowest set bit)."""
|
||||
return n & (-n)
|
||||
|
||||
@@ -16,6 +16,7 @@ import psutil
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import in_wsl
|
||||
from vllm.ray.lazy_utils import is_in_ray_actor
|
||||
|
||||
@@ -111,6 +112,17 @@ def unique_filepath(fn: Callable[[int], Path]) -> Path:
|
||||
# Process management utilities
|
||||
|
||||
|
||||
def _sync_visible_devices_env_vars():
|
||||
"""Sync HIP/CUDA visibility env vars before spawning (ROCm only)."""
|
||||
|
||||
if not current_platform.is_rocm():
|
||||
return
|
||||
|
||||
from vllm.platforms.rocm import _sync_hip_cuda_env_vars
|
||||
|
||||
_sync_hip_cuda_env_vars()
|
||||
|
||||
|
||||
def _maybe_force_spawn():
|
||||
"""Check if we need to force the use of the `spawn` multiprocessing start
|
||||
method.
|
||||
@@ -156,6 +168,10 @@ def get_mp_context():
|
||||
VLLM_WORKER_MULTIPROC_METHOD.
|
||||
"""
|
||||
_maybe_force_spawn()
|
||||
# (ROCm): Sync GPU visibility env vars so spawned children inherit
|
||||
# consistent values. Must run after _maybe_force_spawn and regardless
|
||||
# of whether spawn was already set.
|
||||
_sync_visible_devices_env_vars()
|
||||
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
|
||||
return multiprocessing.get_context(mp_method)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from torch.library import Library, infer_schema
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -641,7 +642,6 @@ def weak_ref_tensor(tensor: Any) -> Any:
|
||||
This ignores 0-size tensors as those don't allocate any memory.
|
||||
"""
|
||||
if isinstance(tensor, torch.Tensor) and tensor.numel() > 0:
|
||||
# return torch.ops._C.weak_ref_tensor(tensor)
|
||||
return ixfops.weak_ref_tensor(tensor)
|
||||
else:
|
||||
return tensor
|
||||
@@ -685,7 +685,7 @@ def get_accelerator_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tens
|
||||
assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
|
||||
return torch.ops._C.get_xpu_view_from_cpu_tensor(cpu_tensor)
|
||||
elif current_platform.is_cuda() or current_platform.is_rocm():
|
||||
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
|
||||
return ixfops.get_cuda_view_from_cpu_tensor(cpu_tensor)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`get_accelerator_view_from_cpu_tensor` is currently "
|
||||
@@ -741,6 +741,41 @@ def is_torch_equal(target: str) -> bool:
|
||||
return Version(importlib.metadata.version("torch")) == Version(target)
|
||||
|
||||
|
||||
HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev")
|
||||
|
||||
if HAS_OPAQUE_TYPE:
|
||||
from torch._opaque_base import OpaqueBase
|
||||
else:
|
||||
OpaqueBase = object # type: ignore[misc, assignment]
|
||||
|
||||
|
||||
class ModuleName(OpaqueBase): # type: ignore[misc]
|
||||
"""Wraps a module name string for use as a torch opaque type.
|
||||
|
||||
When torch >= 2.11, this is registered as a hoisted value-type opaque
|
||||
object so that torch.compile lifts it as a graph input instead of baking
|
||||
it as a constant. This avoids per-layer recompilation for MOE ops.
|
||||
"""
|
||||
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ModuleName) and self.value == other.value
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.value)
|
||||
|
||||
def __fx_repr__(self):
|
||||
return (f"ModuleName({self.value!r})", {ModuleName})
|
||||
|
||||
|
||||
if HAS_OPAQUE_TYPE:
|
||||
from torch._library.opaque_object import register_opaque_type
|
||||
|
||||
register_opaque_type(ModuleName, typ="value", hoist=True)
|
||||
|
||||
|
||||
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
|
||||
def supports_xccl() -> bool:
|
||||
return torch.distributed.is_xccl_available()
|
||||
|
||||
Reference in New Issue
Block a user