################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""A GPU worker class."""
import gc
import os
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch_br
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import set_forward_context
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.multimodal import MultiModalKwargs
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                           SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
                        memory_profiling)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                     WorkerInput)

from vllm_br.platform import SUPAPlatform
from vllm_br.v0.attention.backends.attention_v0 import (
    SUPAFlashAttentionMetadata)
from vllm_br.v0.worker.pooling_model_runner import (
    ModelInputForGPUWithPoolingMetadata, PoolingModelRunner)

# Number of un-captured forward passes run before each graph capture.
_NUM_WARMUP_ITERS = 2


def build_batch_input(batch_size, seq_len=256, device="supa"):
    """Build a synthetic, prefill-only model input for graph capture.

    Every one of the ``batch_size`` sequences has exactly ``seq_len`` random
    token ids drawn from ``[0, 200)``. The attention metadata marks all
    sequences as prefills (``num_decodes=0``), disables KV caching
    (``do_cache=False``) and fills the slot mapping with ``-1``.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of tokens per sequence.
        device: Device on which the input tensors are created.

    Returns:
        A ``ModelInputForGPUWithPoolingMetadata`` describing the batch.
    """
    num_tokens = batch_size * seq_len

    input_tokens = torch.cat([
        torch.randint(0, 200, (seq_len, ), device=device)
        for _ in range(batch_size)
    ])
    input_positions = torch.arange(seq_len, device=device).repeat(batch_size)

    seq_lens = [seq_len] * batch_size
    query_lens = [seq_len] * batch_size
    # Cumulative start offsets: batch_size + 1 entries, 0 .. num_tokens.
    query_start_loc = torch.tensor(
        [i * seq_len for i in range(batch_size + 1)],
        dtype=torch.int32,
        device=device)
    seq_start_loc = [i * seq_len for i in range(batch_size + 1)]
    context_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
    slot_mapping = torch.full((num_tokens, ),
                              -1,
                              dtype=torch.int32,
                              device=device)

    attn_metadata = SUPAFlashAttentionMetadata(
        num_actual_tokens=num_tokens,
        max_query_len=seq_len,
        query_start_loc=query_start_loc,
        max_seq_len=seq_len,
        seq_lens=seq_lens,
        seq_lens_tensor=torch.tensor(seq_lens,
                                     dtype=torch.int32,
                                     device=device),
        # NOTE(review): created without ``device=`` (default device) as in
        # the original code -- confirm this is intended.
        block_table=torch.empty((batch_size, 0), dtype=torch.int32),
        slot_mapping=slot_mapping,
        seq_start_loc=seq_start_loc,
        context_lens=context_lens,
        max_decode_seq_len=0,
        num_prefills=batch_size,
        num_decodes=0,
        num_prefills_tokens=num_tokens,
        do_cache=False,
        use_cascade=False,
        common_prefix_len=0,
        cu_prefix_query_lens=None,
        prefix_kv_lens=None,
        suffix_kv_lens=None,
        scheduler_metadata=0,
        prefix_scheduler_metadata=None,
        _cached_prefill_metadata=None,
        _cached_decode_metadata=None,
        local_attn_metadata=None)

    # build ModelInputForGPUWithPoolingMetadata
    model_input = ModelInputForGPUWithPoolingMetadata(
        input_tokens=input_tokens,
        inputs_embeds=None,
        input_positions=input_positions,
        token_types=None,
        seq_lens=seq_lens,
        query_lens=query_lens,
        lora_mapping=None,
        lora_requests=set(),
        attn_metadata=attn_metadata,
        prompt_adapter_mapping=None,
        prompt_adapter_requests=set(),
        multi_modal_kwargs={},
        request_ids_to_seq_ids={f'embd-{i}': [i]
                                for i in range(batch_size)},
        finished_requests_ids=[],
        virtual_engine=0,
        async_callback=None,
        scheduler_outputs=None,
        previous_hidden_states=None,
        pooling_metadata=None)
    return model_input


class SUPAWorker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
    ) -> None:
        """Create the worker and its model runner.

        Args:
            vllm_config: Full vLLM configuration.
            local_rank: Device index used to build the ``supa:<n>`` device.
            rank: Rank passed to the distributed environment.
            distributed_init_method: Init method for the distributed
                environment.
            is_driver_worker: Forwarded to the model runner.
            model_runner_cls: Optional wrapper class; when given, it is
                called with the already-built model runner and its result
                replaces ``self.model_runner``.
        """
        WorkerBase.__init__(self, vllm_config)
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Return hidden states from target model if the draft model is an
        # mlp_speculator
        speculative_config = self.speculative_config
        model_config = self.model_config
        speculative_args: Dict[str, bool] = {}
        if speculative_config is not None:
            draft_model_type = (
                speculative_config.draft_model_config.hf_config.model_type)
            # Hidden states are requested only when the draft model differs
            # from the target model and is one of the listed speculators.
            if (draft_model_type != model_config.hf_config.model_type
                    and draft_model_type
                    in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp",
                        "mimo_mtp")):
                speculative_args = {"return_hidden_states": True}

        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
        if model_config.runner_type == "pooling":
            ModelRunnerClass = PoolingModelRunner
        elif self.model_config.is_encoder_decoder:
            ModelRunnerClass = EncoderDecoderModelRunner
        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
            vllm_config=self.vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        if model_runner_cls is not None:
            self.model_runner = model_runner_cls(self.model_runner)

        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        # Initialize gpu_cache as pooling models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
        self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

        # Buffers saved before sleep
        self._sleep_saved_buffers: Dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True),
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.SUPA,  # type: ignore
                ],
                schedule=torch.profiler.schedule(wait=0,
                                                 warmup=0,
                                                 active=1,
                                                 repeat=1),
                profile_memory=False,
                record_shapes=True,
                with_stack=False,
                use_supa_simple=True,  # type: ignore
            )
        else:
            self.profiler = None

    def start_profile(self):
        """Start the torch profiler.

        Raises:
            RuntimeError: If profiling was not enabled at construction time.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        """Stop the torch profiler.

        Raises:
            RuntimeError: If profiling was not enabled at construction time.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()

    def sleep(self, level: int = 1) -> None:
        """Not implemented for this worker."""
        raise NotImplementedError

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Not implemented for this worker."""
        raise NotImplementedError

    def init_device(self):
        """Bind the worker to its device and set up distributed state.

        Records the free memory and a baseline ``MemorySnapshot`` that
        ``determine_num_available_blocks`` later profiles against.

        Raises:
            RuntimeError: If the configured device type is not ``"supa"``.
        """
        if self.device_config.device.type == "supa":
            self.device = torch.device(f"supa:{self.local_rank}")
            SUPAPlatform.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            SUPAPlatform.empty_cache()
            self.init_gpu_memory = SUPAPlatform.mem_get_info()[0]
            self.baseline_snapshot = MemorySnapshot()
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model and, if enabled, capture SUPA graphs.

        Graph capture is switched on by the ``ENABLE_VLLM_BR_GRAPH_MODE``
        environment variable (any value other than ``false``/``0``/empty,
        case-insensitive).

        Raises:
            NotImplementedError: If sleep mode is enabled in the model config.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            raise NotImplementedError('SUPA do not support sleep mode')
        self.model_runner.load_model()

        ### capture graphs ###
        graph_mode = os.getenv('ENABLE_VLLM_BR_GRAPH_MODE', 'False').lower()
        if graph_mode not in {'false', '0', ''}:
            self._capture_graphs()

    def _capture_graphs(self) -> None:
        """Capture one SUPA graph per batch size and store it on the runner.

        Results are stored in ``model_runner.graphs``, ``graph_inputs`` and
        ``graph_outputs`` keyed by batch size. Capture is skipped when
        ``model_runner.graph_captured`` is already true.
        """
        runner = self.model_runner
        logger.info("Start capturing graphs...")
        if not hasattr(runner, "graph_captured"):
            runner.graph_captured = False
        if runner.graph_captured:
            return

        # support capturing graphs under multiple batch sizes.
        batch_sizes = [1, 2, 3, 4, 5, 6, 7, 8]
        runner.graphs = {}
        runner.graph_inputs = {}
        runner.graph_outputs = {}
        for bs in batch_sizes:
            if runner.parallel_config.world_size != 1:
                # prevent SCCL capturing by using the same stream with SCCL
                runner.graph_stream = torch.distributed.get_group_stream(
                    get_world_group().device_group)
            else:
                runner.graph_stream = torch_br.supa.Stream()
            runner.default_stream = torch.supa.default_stream()
            runner.copy_done_event = torch_br.supa.Event()
            runner.graph_done_event = torch_br.supa.Event()
            graph = torch.supa.SUPAGraph()

            runner.model_input_in = build_batch_input(bs,
                                                      seq_len=256,
                                                      device=self.device)
            runner.intermediate_tensors = None
            model_input = runner.model_input_in

            model_executable = runner.model
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
            seqlen_agnostic_kwargs = {
                "finished_requests_ids": model_input.finished_requests_ids,
                "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
            } if runner.has_inner_state else {}
            cross_enc_kwargs = {}
            if model_input.token_types is not None:
                cross_enc_kwargs["token_type_ids"] = model_input.token_types

            # Run the model a few times without capturing the graph.
            # This is to make sure that the captured graph does not include the
            # kernel launches for initial benchmarking (e.g., Triton autotune).
            # Note one iteration is not enough for torch.compile
            for _ in range(_NUM_WARMUP_ITERS):
                with set_forward_context(model_input.attn_metadata,
                                         runner.vllm_config,
                                         model_input.virtual_engine):
                    model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=None,
                        **MultiModalKwargs.as_kwargs(
                            multi_modal_kwargs,
                            dtype=runner.model_config.dtype,
                            device=runner.device,
                        ),
                        **cross_enc_kwargs,
                        **seqlen_agnostic_kwargs,
                    )
            # Wait for the warm up operations to finish before proceeding with
            # Graph Capture.
            torch.supa.synchronize()

            with torch.supa.graph(graph, stream=runner.graph_stream), \
                    set_forward_context(model_input.attn_metadata,
                                        runner.vllm_config,
                                        model_input.virtual_engine):
                hidden_or_intermediate_states = model_executable(
                    input_ids=model_input.input_tokens,
                    positions=model_input.input_positions,
                    intermediate_tensors=runner.intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(
                        multi_modal_kwargs,
                        dtype=runner.model_config.dtype,
                        device=runner.device,
                    ),
                    **cross_enc_kwargs,
                    **seqlen_agnostic_kwargs,
                )
            torch.supa.synchronize()

            runner.graphs[bs] = graph
            runner.graph_inputs[bs] = model_input
            runner.graph_outputs[bs] = hidden_or_intermediate_states

        runner.graph_captured = True
        logger.info("capturing graphs Done.")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate sharded-state saving to the model runner."""
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Delegate tensorized-model saving to the model runner."""
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculate the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``, both clamped to be >= 0.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        SUPAPlatform.empty_cache()

        _, total_gpu_memory = SUPAPlatform.mem_get_info()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.baseline_snapshot,
                weights_memory=self.model_runner.model_memory_usage) as result:
            self.model_runner.profile_run()

        self._assert_memory_footprint_increased_during_profiling()

        memory_for_current_instance = total_gpu_memory * \
            self.cache_config.gpu_memory_utilization
        available_kv_cache_memory = (memory_for_current_instance -
                                     result.non_kv_cache_memory)

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        cache_block_size = self.get_cache_block_size_bytes()
        if cache_block_size == 0:
            num_gpu_blocks = 0
            num_cpu_blocks = 0
        else:
            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
            num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                 cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
               "the current vLLM instance can use "
               "total_gpu_memory "
               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
               " x gpu_memory_utilization "
               f"({self.cache_config.gpu_memory_utilization:.2f})"
               f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
               "model weights take "
               f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
               " non_torch_memory takes "
               f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
               " PyTorch activation peak memory takes "
               f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
               " the rest of the memory reserved for KV Cache is "
               f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")

        logger.info(msg)
        # Final cleanup
        gc.collect()

        return num_gpu_blocks, num_cpu_blocks

    def _assert_memory_footprint_increased_during_profiling(self):
        """Assert that used device memory grew relative to the baseline."""
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        free_gpu_memory, total = SUPAPlatform.mem_get_info()
        supa_memory = total - free_gpu_memory
        assert self.baseline_snapshot.supa_memory < supa_memory, (
            "Error in memory profiling. "
            f"Initial used memory {self.baseline_snapshot.supa_memory}, "
            f"currently used memory {supa_memory}. "
            f"This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.

        Raises:
            ValueError: If the block count is invalid for this model (see
                ``raise_if_cache_size_invalid``).
            NotImplementedError: If sleep mode is enabled in the model config.
        """
        raise_if_cache_size_invalid(
            num_gpu_blocks, self.cache_config.block_size,
            self.cache_config.is_attention_free,
            self.model_config.max_model_len,
            self.parallel_config.pipeline_parallel_size)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        if self.vllm_config.model_config.enable_sleep_mode:
            raise NotImplementedError('SUPA do not support sleep mode')
        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        """Create one cache engine per pipeline stage and bind the KV cache."""
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = [
            CacheEngine(self.cache_config, self.model_config,
                        self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.gpu_cache = [
            self.cache_engine[ve].gpu_cache
            for ve in range(self.parallel_config.pipeline_parallel_size)
        ]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.gpu_cache)

    def _warm_up_model(self) -> None:
        """Run dummy passes for compile sizes and capture the model."""
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes
                if x not in self.vllm_config.cuda_graph_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    @property
    def do_metadata_broadcast(self) -> bool:
        """True when tensor parallelism uses more than one rank."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """The per-virtual-engine GPU KV cache, or None if not allocated."""
        return self.gpu_cache

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Convert the request's block operations into a ``WorkerInput``.

        Each block list is turned into an int64 tensor viewed as ``(-1, 2)``.
        """
        virtual_engine = execute_model_req.virtual_engine
        num_steps = execute_model_req.num_steps
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
                                         device="cpu",
                                         dtype=torch.int64).view(-1, 2)
        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
        # `blocks_to_copy` is a gpu tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within cuda kernels.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device=self.device,
                                      dtype=torch.int64).view(-1, 2)

        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
            num_steps=num_steps,
        )

    @torch.inference_mode()
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Issue the non-empty swap-in, swap-out and copy cache operations."""
        virtual_engine = worker_input.virtual_engine
        # Issue cache operations.
        if (worker_input.blocks_to_swap_in is not None
                and worker_input.blocks_to_swap_in.numel() > 0):
            self.cache_engine[virtual_engine].swap_in(
                worker_input.blocks_to_swap_in)
        if (worker_input.blocks_to_swap_out is not None
                and worker_input.blocks_to_swap_out.numel() > 0):
            self.cache_engine[virtual_engine].swap_out(
                worker_input.blocks_to_swap_out)
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)

    def _get_cached_seq_group_metadata(
            self,
            seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                SequenceGroupMetadataDelta]],
            finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
        """Return a list of cached Sequence Group Metadata after updating its
        state.

        It is used because scheduler only sends delta to workers to reduce
        the data payload size. The function also cleans up cache based on
        a given `finished_request_ids`. Finished ids that are not present in
        the cache are ignored.
        """
        new_seq_group_metadata_list = []
        for metadata_or_delta in seq_group_metadata_list:
            request_id = metadata_or_delta.request_id
            if request_id not in self._seq_group_metadata_cache:
                # The first prefill.
                assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                self._seq_group_metadata_cache[request_id] = metadata_or_delta
            else:
                # The first prefill is already cached.
                if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
                    self._seq_group_metadata_cache[request_id].apply_delta(
                        metadata_or_delta)
                else:
                    # If metadata snapshot is sent again, it is
                    # preempted. Reset the cache because we need to start
                    # from scratch.
                    assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                    self._seq_group_metadata_cache[
                        request_id] = metadata_or_delta

            new_seq_group_metadata_list.append(
                self._seq_group_metadata_cache[request_id])

        # Clean up finished ids. ``pop`` with a default is used so that an id
        # this worker never cached does not raise KeyError.
        for finished_id in finished_request_ids:
            self._seq_group_metadata_cache.pop(finished_id, None)

        return new_seq_group_metadata_list

    def _execute_model_spmd(
        self,
        execute_model_req: ExecuteModelRequest,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Expand metadata deltas from the cache, then run the base SPMD step."""
        if execute_model_req is not None:
            new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.finished_requests_ids)

            execute_model_req.seq_group_metadata_list = (
                new_seq_group_metadata_list)
        output = super()._execute_model_spmd(execute_model_req,
                                             intermediate_tensors)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Delegate to the model runner."""
        return self.model_runner.list_loras()

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Delegate to the model runner's prompt-adapter removal."""
        # Previously this called ``remove_lora`` by mistake, which would have
        # removed a LoRA with the same numeric id instead of the adapter.
        return self.model_runner.remove_prompt_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Delegate to the model runner."""
        return self.model_runner.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        """Delegate to the model runner."""
        return self.model_runner.list_prompt_adapters()

    @property
    def max_model_len(self) -> int:
        """Maximum model length from the model config."""
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model runner."""
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Get the size of the KV cache block size in bytes.
        """
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment."""
    parallel_config = vllm_config.parallel_config
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    # "sccl" is the backend name passed to the distributed initializer.
    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank, "sccl")
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    ensure_kv_transfer_initialized(vllm_config)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    """Placeholder dtype check; currently accepts every dtype."""
    # Check if the GPU supports the dtype.
    # TODO: add checkers
    return


def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
                                max_model_len, pipeline_parallel_size) -> None:
    """Validate the number of GPU cache blocks against the model config.

    Raises:
        ValueError: If an attention-free model was given a non-zero number of
            blocks, if an attention model was given no blocks, or if the
            per-pipeline-stage capacity (``block_size * (num_gpu_blocks //
            pipeline_parallel_size)``) is smaller than ``max_model_len``.
    """
    if is_attention_free and num_gpu_blocks != 0:
        raise ValueError("No memory should be allocated for the cache blocks "
                         f"for an attention-free model, but {num_gpu_blocks} "
                         "blocks are allocated.")
    if not is_attention_free and num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")
    max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size)
    if not is_attention_free and max_model_len > max_seq_len:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({max_seq_len}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")