Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -2,10 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from datetime import timedelta
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING

import regex as re
import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available

import vllm.envs as envs
from vllm.logger import init_logger
@@ -61,13 +65,29 @@ _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x744c": "AMD_Radeon_RX7900XTX",
}

# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        assert val == cuda_val
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = val

def _sync_hip_cuda_env_vars():
    """Ensure HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES are consistent.
    Treats empty string as unset. Raises on genuine conflicts."""
    hip_val = os.environ.get("HIP_VISIBLE_DEVICES") or None
    cuda_val = os.environ.get("CUDA_VISIBLE_DEVICES") or None

    if hip_val is not None and cuda_val is not None:
        if hip_val != cuda_val:
            raise ValueError(
                f"Inconsistent GPU visibility env vars: "
                f"HIP_VISIBLE_DEVICES='{hip_val}' vs "
                f"CUDA_VISIBLE_DEVICES='{cuda_val}'. "
                f"Please set only one, or ensure they match."
            )
    elif hip_val is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = hip_val
    elif cuda_val is not None:
        os.environ["HIP_VISIBLE_DEVICES"] = cuda_val


# Sync at import time - catches misconfigurations from process start.
_sync_hip_cuda_env_vars()

# AMDSMI utils
# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
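To make the resolution rules concrete, below is a hedged, standalone sketch; `sync_visibility` is a hypothetical stand-in that mirrors the logic of `_sync_hip_cuda_env_vars()` above but operates on a plain dict instead of `os.environ`, so it can be exercised without touching the real process environment.

# Hypothetical stand-in for the new helper, shown outside the diff.
def sync_visibility(env: dict[str, str]) -> dict[str, str]:
    hip_val = env.get("HIP_VISIBLE_DEVICES") or None   # empty string == unset
    cuda_val = env.get("CUDA_VISIBLE_DEVICES") or None
    if hip_val is not None and cuda_val is not None:
        if hip_val != cuda_val:
            raise ValueError("Inconsistent GPU visibility env vars")
    elif hip_val is not None:
        env["CUDA_VISIBLE_DEVICES"] = hip_val
    elif cuda_val is not None:
        env["HIP_VISIBLE_DEVICES"] = cuda_val
    return env

# HIP only -> mirrored to CUDA; CUDA only -> mirrored to HIP; empty == unset.
assert sync_visibility({"HIP_VISIBLE_DEVICES": "0,1"})["CUDA_VISIBLE_DEVICES"] == "0,1"
assert sync_visibility({"CUDA_VISIBLE_DEVICES": "2"})["HIP_VISIBLE_DEVICES"] == "2"
assert "CUDA_VISIBLE_DEVICES" not in sync_visibility({"HIP_VISIBLE_DEVICES": ""})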
@@ -131,6 +151,77 @@ _ON_GFX942 = "gfx942" in _GCN_ARCH
_ON_GFX950 = "gfx950" in _GCN_ARCH


def _capability_from_gcn_arch(gcn_arch: str) -> tuple[int, int] | None:
    """
    Parse (major, minor) from a GCN arch string, mirroring how
    HIP derives hipDeviceProp_t.major / .minor.

    Format: gfx<MAJOR><MINOR><STEPPING>
      - 1-digit major (gfx9xx): "gfx" + M + m + stepping
      - 2-digit major (gfx1xxx): "gfx" + MM + m + stepping

    Examples:
        gfx90a  -> (9, 0)   gfx942  -> (9, 4)   gfx950  -> (9, 5)
        gfx1100 -> (11, 0)  gfx1101 -> (11, 0)  gfx1200 -> (12, 0)

    Returns None only when the string is not gfx-prefixed at all
    (i.e. not a ROCm arch string). Raises on any string that looks
    like a GCN arch but does not match a known layout.
    """
    m = re.match(r"gfx(\d+)", gcn_arch)
    if not m:
        # Not a gfx string at all — caller should fall back to torch.cuda
        return None

    digits = m.group(1)
    n = len(digits)

    if n < 2:
        raise ValueError(
            f"GCN arch '{gcn_arch}' has too few digits ({n}) after 'gfx' "
            f"to derive a (major, minor) capability. "
            f"Please file a vLLM issue with your GPU model."
        )

    if n in (2, 3):
        # 1-digit major: gfx9 family
        # len 2: major + minor (e.g. gfx90 from gfx90a)
        # len 3: major + minor + step (e.g. gfx942)
        major = int(digits[0])
        minor = int(digits[1])
    elif n == 4:
        # 2-digit major: gfx10xx, gfx11xx, gfx12xx
        # major(2) + minor(1) + stepping(1)
        major = int(digits[:2])
        minor = int(digits[2])
    elif n >= 5:
        raise ValueError(
            f"GCN arch '{gcn_arch}' has {n} digits after 'gfx', which "
            f"exceeds the known 4-digit layout (MMms). Cannot determine "
            f"major/minor split unambiguously. "
            f"Please file a vLLM issue with your GPU model."
        )

    if major < 9:
        raise ValueError(
            f"Parsed unknown ROCm architecture from GCN arch '{gcn_arch}': "
            f"major={major}, minor={minor}. "
            f"Major version < 9 is not expected for any supported AMD GPU. "
            f"Please file a vLLM issue with your GPU model."
        )

    if major > 12:
        raise ValueError(
            f"Parsed unknown ROCm architecture from GCN arch '{gcn_arch}': "
            f"major={major}, minor={minor}. "
            f"Major version > 12 is beyond currently known AMD generations. "
            f"Please file a vLLM issue with your GPU model so support "
            f"can be added."
        )

    return (major, minor)


def on_gfx1x() -> bool:
    return _ON_GFX1X

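As an unofficial illustration of the gfx<MAJOR><MINOR><STEPPING> split documented above, this self-contained sketch (stdlib `re` and a hypothetical `parse_gfx` helper, not the real vLLM function) reproduces the same major/minor derivation for the happy paths.

import re

def parse_gfx(gcn_arch: str) -> tuple[int, int] | None:
    # Hypothetical, simplified mirror of _capability_from_gcn_arch.
    m = re.match(r"gfx(\d+)", gcn_arch)
    if not m:
        return None
    digits = m.group(1)
    if len(digits) in (2, 3):      # gfx9 family: 1-digit major
        return int(digits[0]), int(digits[1])
    if len(digits) == 4:           # gfx10/11/12: 2-digit major
        return int(digits[:2]), int(digits[2])
    raise ValueError(f"unrecognized GCN arch layout: {gcn_arch!r}")

assert parse_gfx("gfx942") == (9, 4)
assert parse_gfx("gfx90a") == (9, 0)
assert parse_gfx("gfx1100") == (11, 0)
assert parse_gfx("not_rocm") is None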
@@ -441,6 +532,15 @@ class RocmPlatform(Platform):
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        cap = _capability_from_gcn_arch(_GCN_ARCH)
        if cap is not None:
            return DeviceCapability(major=cap[0], minor=cap[1])

        logger.warning_once(
            "Could not derive device capability from GCN arch '%s', "
            "falling back to torch.cuda (this will initialize CUDA).",
            _GCN_ARCH,
        )
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

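Assuming a ROCm build where RocmPlatform is the active platform, a caller-side sketch of the fast path might look like the following (illustrative only; on arch strings the parser cannot handle, the method still falls back to torch.cuda as shown above).

from vllm.platforms import current_platform

# On a ROCm system this now resolves from the GCN arch string without
# initializing CUDA; e.g. gfx942 hardware reports major=9, minor=4.
cap = current_platform.get_device_capability()
if cap is not None:
    print(f"device capability: {cap.major}.{cap.minor}")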
@@ -492,7 +592,6 @@ class RocmPlatform(Platform):
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
        use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()
        # Aiter rms norm performs best when CUDA Graph capture is enabled.
        if (
            use_aiter_rms_norm
@@ -519,9 +618,9 @@ class RocmPlatform(Platform):
            and "-grouped_topk" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+grouped_topk")
        # Enable rotary embedding when using AITER if its not disabled by user
        # Enable rotary embedding customop when using AITER if not disabled by user
        if (
            use_aiter_triton_rope
            rocm_aiter_ops.is_enabled()
            and "+rotary_embedding" not in compilation_config.custom_ops
            and "-rotary_embedding" not in compilation_config.custom_ops
        ):
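Since the hunk above gates the AITER rotary-embedding custom op on the "+/-rotary_embedding" entries in compilation_config.custom_ops, a user-side sketch of opting out could look like the following; the model name is a placeholder and the exact config spelling may differ across vLLM versions.

from vllm import LLM

# "-rotary_embedding" asks vLLM not to enable the rotary-embedding custom op,
# matching the membership checks in the hunk above.
llm = LLM(
    model="<your-model>",  # placeholder
    compilation_config={"custom_ops": ["-rotary_embedding"]},
)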
@@ -656,6 +755,37 @@ class RocmPlatform(Platform):
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(
            prefix_store, group_rank, group_size, backend_options
        )
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        return cuda_device_count_stateless()
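Finally, a hedged single-rank sketch of how the new stateless process-group helper might be driven; the store address, port, and prefix are placeholders, and actually running it requires a ROCm build with NCCL/RCCL available.

from datetime import timedelta

import torch.distributed as dist

from vllm.platforms.rocm import RocmPlatform

# Placeholder rendezvous store for a single-rank illustration.
store = dist.TCPStore("127.0.0.1", 29500, world_size=1, is_master=True)
prefix_store = dist.PrefixStore("vllm_pg_demo/", store)

pg = RocmPlatform.stateless_init_device_torch_dist_pg(
    backend="nccl",
    prefix_store=prefix_store,
    group_rank=0,
    group_size=1,
    timeout=timedelta(seconds=60),
)
print(pg.size())  # -> 1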