# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)

# Naming convention used by the input batch: "foo_cpu_tensor" is the CPU-side
# copy of the device tensor stored under attribute "foo".
_CPU_TENSOR_SUFFIX = "_cpu_tensor"


class CPUModelRunner(GPUModelRunner):
    """Model runner for the CPU backend.

    Reuses the GPU model runner's logic, but aliases every device tensor to
    its CPU counterpart, disables CUDA graphs and cascade attention, and
    stubs out the CUDA-only synchronization hooks.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        # The parent __init__ constructs CUDA events/streams; temporarily
        # replace those classes with no-op placeholders while it runs.
        with _torch_cuda_wrapper():
            super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        # CUDA-graph capture and cascade attention are GPU-only features.
        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()

    def _postprocess_tensors(self) -> None:
        """Alias all "device" tensors/buffers to their CPU twins.

        After this, code written against the GPU runner transparently
        operates on CPU memory: every CpuGpuBuffer's .gpu points at its
        .cpu tensor, and each "<name>" device tensor on the input batch is
        replaced by its "<name>_cpu_tensor" counterpart.
        """

        def replace_tensor(obj: Any, cpu_attr_name: str,
                           device_attr_name: str) -> None:
            # Re-point the device attribute at the CPU tensor when both
            # attributes exist and are tensors.
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if cpu_tensor is not None and device_tensor is not None:
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(obj, device_attr_name, cpu_tensor)

        for v in vars(self).values():
            if isinstance(v, CpuGpuBuffer):
                v.gpu = v.cpu

        for k, v in vars(self.input_batch).items():
            if k.endswith(_CPU_TENSOR_SUFFIX) and isinstance(v, torch.Tensor):
                # "foo_cpu_tensor" shadows the device tensor attribute "foo".
                replace_tensor(self.input_batch, k,
                               k.removesuffix(_CPU_TENSOR_SUFFIX))

        for block_table in self.input_batch.block_table.block_tables:
            for v in vars(block_table).values():
                if isinstance(v, CpuGpuBuffer):
                    v.gpu = v.cpu

    def load_model(self, eep_scale_up: bool = False) -> None:
        """Load the model weights (and optional LoRA adapters) onto CPU."""
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.vllm_config,
                                              self.device)

    def get_model(self) -> nn.Module:
        return self.model

    def warming_up_model(self) -> None:
        """Warm up the model by compiling the generic-shape graph once."""
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape.
        with _set_global_compilation_settings(self.vllm_config):
            self._dummy_run(
                min(
                    max(16, self.max_num_reqs),
                    self.scheduler_config.max_num_batched_tokens,
                ))
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        # No CUDA device properties to query on CPU.
        pass

    def _sync_device(self) -> None:
        # CPU execution is synchronous; nothing to wait on.
        pass

    def get_dp_padding(self, num_tokens: int) -> tuple[int, torch.Tensor | None]:
        # Note: For CPU backend, dp padding is not required for now.
        return 0, None


@contextmanager
def _torch_cuda_wrapper():
    """Temporarily replace torch.Event / torch.cuda.Stream with no-op stubs.

    Lets GPUModelRunner.__init__ run on a CUDA-less host; the originals are
    always restored on exit.
    """

    class _EventPlaceholder:

        def __init__(self, *args, **kwargs) -> None:
            self.record = lambda: None
            self.synchronize = lambda: None

    class _StreamPlaceholder:

        def __init__(self, *args, **kwargs) -> None:
            pass

    cuda_event = torch.Event
    cuda_stream = torch.cuda.Stream
    try:
        torch.Event = _EventPlaceholder
        torch.cuda.Stream = _StreamPlaceholder
        yield
    finally:
        # Restore the real classes even if the body raised.
        torch.Event = cuda_event
        torch.cuda.Stream = cuda_stream


@contextmanager
def _set_global_compilation_settings(config: VllmConfig):
    """Apply CPU-specific global inductor settings for the duration of the
    context, restoring the previous values afterwards."""
    import torch._inductor.config as torch_inductor_config

    inductor_config = config.compilation_config.inductor_compile_config
    # Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
    freezing_value = torch_inductor_config.freezing
    try:
        if inductor_config.get("max_autotune", False):
            torch_inductor_config.freezing = True
        yield
    finally:
        torch_inductor_config.freezing = freezing_value