Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -6,10 +6,13 @@ pynvml. However, it should not initialize cuda context.
 import os
 from collections.abc import Callable
+from datetime import timedelta
 from functools import cache, wraps
 from typing import TYPE_CHECKING, TypeVar

 import torch
+from torch.distributed import PrefixStore, ProcessGroup
+from torch.distributed.distributed_c10d import is_nccl_available
 from typing_extensions import ParamSpec

 # import custom ops, trigger op registration
@@ -414,6 +417,7 @@ class CudaPlatformBase(Platform):
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TRITON_ATTN,
+            AttentionBackendEnum.TORCH_SDPA,
             AttentionBackendEnum.FLASHINFER,
         ]

     @classmethod
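This list is the backend priority order that CudaPlatformBase consults when no attention backend is forced explicitly. One of these names can also be pinned through vLLM's VLLM_ATTENTION_BACKEND environment variable; a minimal sketch, assuming a working vLLM install on this platform and using facebook/opt-125m purely as an arbitrary small test model:

    # Minimal sketch: pinning one of the backends listed above via the
    # VLLM_ATTENTION_BACKEND environment variable (set before engine start).
    import os

    os.environ["VLLM_ATTENTION_BACKEND"] = "TORCH_SDPA"  # any enum name above

    from vllm import LLM  # heavyweight import; reads the env var at init

    llm = LLM(model="facebook/opt-125m")  # arbitrary small test model
    print(llm.generate(["Hello"])[0].outputs[0].text)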
@@ -481,6 +485,37 @@ class CudaPlatformBase(Platform):
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        assert is_nccl_available()
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(
+            prefix_store, group_rank, group_size, backend_options
+        )
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+        pg._set_default_backend(backend_type)
+        backend_class._set_sequence_number_for_group()
+
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
+
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
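The new stateless_init_device_torch_dist_pg classmethod builds an NCCL-backed ProcessGroup directly from a PrefixStore, bypassing torch.distributed's global state. A minimal driver sketch of how it could be exercised across two processes; the rendezvous host, port, and the vllm.platforms.cuda.CudaPlatform import path are assumptions for illustration, not part of the commit:

    # Hypothetical standalone driver for the new classmethod.
    from datetime import timedelta

    import torch
    from torch.distributed import PrefixStore, TCPStore

    from vllm.platforms.cuda import CudaPlatform  # assumed import path


    def init_stateless_pg(rank: int, world_size: int):
        # Rank 0 hosts the TCPStore that all ranks rendezvous through.
        store = TCPStore("127.0.0.1", 29500, world_size, is_master=(rank == 0))
        # Namespacing the store keeps this group's keys isolated.
        prefix_store = PrefixStore("stateless_pg/", store)
        torch.cuda.set_device(rank)
        pg = CudaPlatform.stateless_init_device_torch_dist_pg(
            backend="nccl",
            prefix_store=prefix_store,
            group_rank=rank,
            group_size=world_size,
            timeout=timedelta(seconds=60),
        )
        # The group is "stateless": it is never registered with the default
        # torch.distributed world, so collectives are issued on it directly.
        t = torch.full((1,), rank, device="cuda", dtype=torch.float32)
        pg.allreduce([t]).wait()
        return t  # now holds the sum of ranks across the group

Keeping the group out of the global c10d state means several such groups can coexist without interfering with a torch.distributed.init_process_group world.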
@@ -556,7 +591,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
             physical_device_id = cls.device_id_to_physical_device_id(device_id)
             handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
             major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
-            return DeviceCapability(major=major, minor=minor)
+            return DeviceCapability(major=9, minor=0)
         except RuntimeError:
             return None
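The overlay now reports a fixed compute capability of 9.0 instead of returning the value NVML produced. For comparison, the query that the replaced line relied on can be run standalone with pynvml; a minimal sketch, with device index 0 chosen arbitrarily:

    # Standalone sketch of the NVML compute-capability query used above.
    import pynvml

    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # arbitrary device
        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        print(f"compute capability: {major}.{minor}")
    finally:
        pynvml.nvmlShutdown()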