update
This commit is contained in:
114
vllm/v1/worker/xpu_worker.py
Normal file
114
vllm/v1/worker/xpu_worker.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.profiler.wrapper import TorchProfilerWrapper
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
from vllm.v1.worker.xpu_model_runner import XPUModelRunner
|
||||
|
||||
from .utils import request_memory
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUWorker(Worker):
|
||||
"""A XPU worker class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config, local_rank, rank, distributed_init_method, is_driver_worker
|
||||
)
|
||||
device_config = self.device_config
|
||||
assert device_config.device_type == "xpu"
|
||||
assert current_platform.is_xpu()
|
||||
|
||||
# Torch profiler. Enabled and configured through profiler_config.
|
||||
self.profiler: Any | None = None
|
||||
profiler_config = vllm_config.profiler_config
|
||||
if profiler_config.profiler == "torch":
|
||||
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
|
||||
self.profiler = TorchProfilerWrapper(
|
||||
profiler_config,
|
||||
worker_name=worker_name,
|
||||
local_rank=self.local_rank,
|
||||
activities=["CPU", "XPU"],
|
||||
)
|
||||
|
||||
def init_device(self):
|
||||
device = self.device_config.device
|
||||
if (
|
||||
isinstance(device, torch.device)
|
||||
and device.type == "xpu"
|
||||
and current_platform.is_xpu()
|
||||
):
|
||||
self.device = torch.device(f"xpu:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
current_platform.check_if_supports_dtype(self.model_config.dtype)
|
||||
torch.xpu.empty_cache()
|
||||
self.init_gpu_memory = torch.xpu.get_device_properties(
|
||||
self.local_rank
|
||||
).total_memory
|
||||
else:
|
||||
raise RuntimeError(f"Not support device type: {self.device_config.device}")
|
||||
|
||||
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
|
||||
ENV_LOCAL_WORLD_SIZE = os.getenv(
|
||||
"LOCAL_WORLD_SIZE", str(self.parallel_config.world_size)
|
||||
)
|
||||
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
|
||||
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
|
||||
os.environ["LOCAL_RANK"] = str(self.local_rank)
|
||||
|
||||
init_worker_distributed_environment(
|
||||
self.vllm_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
current_platform.dist_backend,
|
||||
)
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
# Now take memory snapshot after NCCL is initialized
|
||||
gc.collect()
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
# take current memory snapshot
|
||||
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
|
||||
self.requested_memory = request_memory(init_snapshot, self.cache_config)
|
||||
logger.debug("worker init memory snapshot: %r", self.init_snapshot)
|
||||
logger.debug(
|
||||
"worker requested memory: %sGiB", format_gib(self.requested_memory)
|
||||
)
|
||||
|
||||
# Initialize workspace manager
|
||||
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
|
||||
init_workspace_manager(self.device, num_ubatches)
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner = XPUModelRunner( # type: ignore
|
||||
self.vllm_config, self.device
|
||||
)
|
||||
|
||||
if self.rank == 0:
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
report_usage_stats(self.vllm_config)
|
||||
Reference in New Issue
Block a user