[gpt-oss] Add gpt-oss bf16 support
vllm/v1/worker/cpu_model_runner.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import Any

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)


class CPUModelRunner(GPUModelRunner):

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        super().__init__(vllm_config, device)

        assert device == torch.device("cpu")
        assert self.speculative_config is None, "spec decode is not supported."

        self.use_cuda_graph = False
        self.cascade_attn_enabled = False

        self._postprocess_tensors()

    def _postprocess_tensors(self) -> None:
        # Note: replace device tensors with cpu tensors
        def replace_tensor(obj: Any, cpu_attr_name: str,
                           device_attr_name: str) -> None:
            cpu_tensor = getattr(obj, cpu_attr_name, None)
            device_tensor = getattr(obj, device_attr_name, None)
            if cpu_tensor is not None and device_tensor is not None:
                assert isinstance(cpu_tensor, torch.Tensor)
                assert isinstance(device_tensor, torch.Tensor)
                setattr(obj, device_attr_name, cpu_tensor)

        for k, v in vars(self).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self, k, k[:-4])

        for k, v in vars(self.input_batch).items():
            if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch, k, k[:-11])

        for k, v in vars(self.input_batch.block_table).items():
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch.block_table, k, k[:-4])

    def load_model(self) -> None:
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

        if self.lora_config:
            self.model = self.load_lora_model(self.model, self.model_config,
                                              self.scheduler_config,
                                              self.lora_config, self.device)

    def warming_up_model(self) -> None:
        logger.info("Warming up model for the compilation...")
        # Only generate graph for the generic shape
        self._dummy_run(max(16, self.max_num_reqs))
        logger.info("Warming up done.")

    def _init_device_properties(self) -> None:
        pass

    def _sync_device(self) -> None:
        pass


@contextmanager
def _set_global_compilation_settings():
    import torch._inductor.config

    # Note: The CPPGEMM backend requires freezing parameters.
    freezing_value = torch._inductor.config.freezing
    torch._inductor.config.freezing = True
    # Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
    # including object type dict"
    force_disable_caches = torch._inductor.config.force_disable_caches
    torch._inductor.config.force_disable_caches = True
    yield
    torch._inductor.config.freezing = freezing_value
    torch._inductor.config.force_disable_caches = force_disable_caches
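
Note on _postprocess_tensors: the runner relies on a naming convention where an attribute named <name>_cpu holds the CPU copy of a device tensor named <name>. Since everything already lives on the CPU here, the device attribute can simply be aliased to the CPU one, making later device-tensor accesses free. A minimal, self-contained sketch of the same trick (the Holder class and attribute names are illustrative, not vLLM APIs):

import torch


class Holder:
    """Hypothetical object following the `<name>_cpu` pairing convention."""

    def __init__(self) -> None:
        self.positions_cpu = torch.zeros(8, dtype=torch.int64)
        # On a GPU runner this would live on the device; on CPU it is
        # redundant, so the postprocess step aliases it to the CPU copy.
        self.positions = torch.zeros(8, dtype=torch.int64)


holder = Holder()
for name, value in vars(holder).items():
    if name.endswith("_cpu") and isinstance(value, torch.Tensor):
        # Strip the "_cpu" suffix to find the paired device attribute.
        setattr(holder, name[:-4], value)

# After postprocessing, both names refer to one tensor: writes through
# `positions_cpu` are visible through `positions` with no copy.
holder.positions_cpu[0] = 42
assert holder.positions[0] == 42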
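Note on warming_up_model: running _dummy_run once with a single generic batch size lets the compiler trace a dynamic-shape graph up front, so later requests with other batch sizes do not trigger recompilation. A rough stand-alone illustration of that warmup pattern (the step function is hypothetical, not part of this diff):

import torch


@torch.compile(dynamic=True)
def step(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


# Warm up once with a representative batch; later calls with different
# batch sizes reuse the dynamic-shape graph instead of recompiling.
step(torch.randn(16, 8))
step(torch.randn(3, 8))  # served by the already-compiled graph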
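Note on _set_global_compilation_settings: it mutates process-global inductor flags, so callers are expected to wrap only the compile step. A plausible call site, assuming a helper like the one below (hypothetical, not part of this diff):

import torch


def compile_for_cpu(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: flip the inductor flags only while compiling,
    # then let the context manager restore the previous values.
    with _set_global_compilation_settings():
        return torch.compile(model, backend="inductor")

One caveat: as written, the flags are restored only on the normal path; wrapping the yield in try/finally would also restore them if compilation raises.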