Files
2026-03-10 13:31:25 +08:00

430 lines
18 KiB
Python

################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
"""A GPU worker class."""
import copy
import datetime
import gc
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch.nn as nn
import vllm.envs as envs
import vllm_br.envs as br_envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.worker_base import WorkerBase
from vllm_br.platform import SUPAPlatform
from vllm_br.utils import GiB_bytes, SUPAMemorySnapshot
from vllm_br.v1.worker.model_runner import SUPAModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
class SUPAWorker(WorkerBase):
    """Worker that drives a ``SUPAModelRunner`` on a single ``supa`` device.

    The worker owns device selection, distributed initialisation, memory
    profiling for the KV cache, warm-up, and the pipeline-parallel
    send/receive around each model step. Sleep mode is not supported.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        """Store configuration and, if requested, set up the torch profiler.

        Args:
            vllm_config: Full vLLM configuration; forwarded to the base class.
            local_rank: Rank of this worker on the local node.
            rank: Global rank of this worker.
            distributed_init_method: Init method forwarded to
                ``init_distributed_environment``.
            is_driver_worker: Forwarded to the base class.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )
        self.kv_transfer_config = vllm_config.kv_transfer_config

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Buffers saved before sleep
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True),
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.SUPA,  # type: ignore
                ],
                schedule=torch.profiler.schedule(wait=0,
                                                 warmup=0,
                                                 active=1,
                                                 repeat=1),
                profile_memory=False,
                record_shapes=True,
                with_stack=False,
                use_supa_simple=True,  # type: ignore
            )
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Not supported on SUPA; always raises ``NotImplementedError``."""
        raise NotImplementedError('SUPA do not support sleep mode')

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Not supported on SUPA; always raises ``NotImplementedError``."""
        raise NotImplementedError('SUPA do not support sleep mode')

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the device and CPU KV-cache block counts in the config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Select the SUPA device, init distributed state, build the runner.

        Raises:
            RuntimeError: If the configured device type is not ``"supa"``.
        """
        if self.device_config.device.type != "supa":
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        # With KV transfer configured, the optional "device_cursor" extra
        # config value offsets the device index from the local rank.
        device_index = self.local_rank
        if self.kv_transfer_config is not None:
            device_cursor = self.kv_transfer_config.get_from_extra_config(
                "device_cursor", 0)
            device_index += int(device_cursor)
        self.device = torch.device(f"supa:{device_index}")
        SUPAPlatform.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking the memory
        # reading, so that communication buffers are already allocated when
        # memory is measured.
        self._init_worker_distributed_environment()

        # Set random seed.
        set_random_seed(self.model_config.seed)

        gc.collect()
        torch.supa.empty_cache()
        # NOTE(review): this stores the device's *total* memory, while
        # `determine_available_memory` reports it as "Initial free memory".
        # Confirm which of the two is intended.
        self.init_gpu_memory = SUPAPlatform.get_device_total_memory()

        # Construct the model runner
        self.model_runner: SUPAModelRunner = SUPAModelRunner(  # type: ignore
            self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model weights through the model runner.

        Raises:
            NotImplementedError: If sleep mode is enabled in the model config.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            raise NotImplementedError('SUPA do not support sleep mode')
        self.model_runner.load_model()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profile the model to determine how much memory the KV cache may use.

        A profiling forward pass is run with dummy inputs, after which the
        memory still allocated through torch is subtracted from
        ``total_gpu_memory * gpu_memory_utilization``.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            The number of bytes available for the KV cache.
        """
        torch.supa.empty_cache()
        _, total_gpu_memory = torch.supa.mem_get_info()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        before_profile = SUPAMemorySnapshot()
        after_profile = SUPAMemorySnapshot()
        before_profile.measure()
        self.model_runner.profile_run()
        after_profile.measure()

        free_gpu_memory, _ = torch.supa.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_gpu_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        # NOTE(review): this is the memory currently allocated after the
        # profile run, not a recorded peak - confirm this is intended.
        peak_memory = torch.supa.memory_allocated()

        # Check for any memory left around that may have been allocated on the
        # gpu outside of `torch`. Collective-communication operations, for
        # example, can use a few GB during a forward pass.
        torch.supa.empty_cache()
        torch_allocated_bytes = SUPAPlatform.get_memory_stats(
            self.device, "allocated_bytes.all.current")
        # Read free/total from one call so both come from the same snapshot.
        free_after_cleanup, total_after_cleanup = torch.supa.mem_get_info()
        total_allocated_bytes = total_after_cleanup - free_after_cleanup
        # Non-torch allocations are only reported in the log below; they are
        # deliberately not added to `peak_memory`.
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes

        memory_for_current_instance = (
            total_gpu_memory * self.cache_config.gpu_memory_utilization)
        available_kv_cache_memory = memory_for_current_instance - peak_memory

        diff_profile = after_profile - before_profile
        msg = (f"Memory profiling takes {diff_profile.timestamp:.2f} seconds\n"
               "the current vLLM instance can use "
               "total_gpu_memory "
               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
               " x gpu_memory_utilization "
               f"({self.cache_config.gpu_memory_utilization:.2f})"
               f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
               "model weights take "
               f"{(self.model_runner.model_memory_usage / GiB_bytes):.2f}GiB;"
               " non_torch_memory takes "
               f"{(non_torch_allocations / GiB_bytes):.2f}GiB;"
               " PyTorch activation peak memory takes "
               f"{(diff_profile.torch_peak / GiB_bytes):.2f}GiB;"
               " the rest of the memory reserved for KV Cache is "
               f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
        logger.info(msg)

        return int(available_kv_cache_memory)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV-cache spec reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        Raises:
            NotImplementedError: If sleep mode is enabled in the model config.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            raise NotImplementedError('SUPA do not support sleep mode')
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Warm up compile sizes, capture graphs, and warm up the sampler."""
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes
                if x not in self.scheduler_config.cuda_graph_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size,
                                         skip_eplb=True,
                                         remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `SUPAPlatform.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )
            hidden_states, last_hidden_states = \
                self.model_runner._dummy_run(
                    num_tokens=max_num_reqs,
                    skip_eplb=True,
                )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(
                    hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model runner reports as supported."""
        return self.model_runner.get_supported_tasks()

    def _recv_intermediate_tensors(self) -> IntermediateTensors:
        """Receive the tensor dict sent over the pipeline-parallel group.

        When ``VLLM_PP_CPU_SEND_RECV`` is set, every received tensor is moved
        to the current SUPA device before being wrapped.
        """
        received = get_pp_group().recv_tensor_dict()
        if br_envs.VLLM_PP_CPU_SEND_RECV:
            # use cpu send/recv; look the target device up once for all
            # tensors
            target_device = torch.supa.current_device()
            received = {
                name: tensor.to(target_device)
                for name, tensor in received.items()
            }
        return IntermediateTensors(received)

    def _send_intermediate_tensors(
            self, tensors: dict[str, torch.Tensor]) -> None:
        """Send ``tensors`` over the pipeline-parallel group.

        When ``VLLM_PP_CPU_SEND_RECV`` is set, tensors are copied to the CPU
        before being sent.
        """
        if br_envs.VLLM_PP_CPU_SEND_RECV:
            # use cpu send/recv
            tensors = {name: tensor.cpu() for name, tensor in tensors.items()}
        get_pp_group().send_tensor_dict(tensors)

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
        """Run one model step, handling pipeline-parallel send/receive.

        Returns:
            The runner's output when it is a ``ModelRunnerOutput`` or
            ``AsyncModelRunnerOutput``. Otherwise the intermediate tensors are
            sent on, and the result is ``None``,
            ``EMPTY_MODEL_RUNNER_OUTPUT``, or a copy of it carrying the KV
            connector output.
        """
        intermediate_tensors = None
        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
        if forward_pass and not get_pp_group().is_first_rank:
            # The tensor dict is received without an `all_gather_group`.
            intermediate_tensors = self._recv_intermediate_tensors()

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert parallel_config.distributed_executor_backend != (
            "external_launcher") and not get_pp_group().is_last_rank

        self._send_intermediate_tensors(output.tensors)

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if (not kv_connector_output.finished_sending
                and not kv_connector_output.finished_recving):
            return EMPTY_MODEL_RUNNER_OUTPUT

        passthrough = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        passthrough.kv_connector_output = kv_connector_output
        return passthrough

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Return the draft token ids taken from the model runner."""
        return self.model_runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the torch profiler.

        Raises:
            RuntimeError: If no profiler was configured at construction.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def execute_dummy_batch(self) -> None:
        """Run a dummy model step with a single token."""
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA add request to the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal request to the model runner."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a LoRA pin request to the model runner."""
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        """No-op health check."""
        # worker will always be healthy as long as it's running.
        return

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the runner's model via ``ShardedStateLoader.save_model``.

        Args:
            path: Destination passed through to the loader.
            pattern: Optional pattern passed through to the loader.
            max_size: Optional size limit passed through to the loader.
        """
        # Prefer the package-level export; fall back to the legacy module
        # path for vLLM versions that still ship it.
        try:
            from vllm.model_executor.model_loader import ShardedStateLoader
        except ImportError:
            from vllm.model_executor.model_loader.loader import (
                ShardedStateLoader)
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def _init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        set_custom_all_reduce(
            not self.parallel_config.disable_custom_all_reduce)

        init_distributed_environment(self.parallel_config.world_size,
                                     self.rank,
                                     self.distributed_init_method,
                                     self.local_rank,
                                     "sccl",
                                     timeout=datetime.timedelta(seconds=100))

        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

        ensure_kv_transfer_initialized(self.vllm_config)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
# TODO: add checkers
return
if torch_dtype == torch.bfloat16: # noqa: SIM102
capability = SUPAPlatform.get_device_capability()
gpu_name = SUPAPlatform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.")