################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional

from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import logger

from vllm_br.config.compilation import SUPAGraphMode
from vllm_br.forward_context import BatchDescriptor

# Default capture sizes: 1, 2, 4, then every multiple of the alignment
# from 8 up to 264 (8 * 33).
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 34)
]


class SupagraphDispatcher:
    """
    Runtime supagraph dispatcher to dispatch keys for multiple sets of
    supagraphs.

    The dispatcher stores one set of dispatch keys per supagraph runtime mode
    (PIECEWISE, FULL and FULL_DECODE_ONLY). The keys are populated by
    `initialize_supagraph_keys`, and the keys stored in the dispatcher are the
    only source of truth for valid supagraphs that can be dispatched at
    runtime.

    At runtime, the dispatch method generates the runtime supagraph mode
    (FULL, FULL_DECODE_ONLY, or NONE for no supagraph) and the valid key
    (batch descriptor) based on the input key. After dispatching (communicate
    via forward context), the supagraph wrappers will trust the dispatch key
    to do either capturing or replaying (if mode matched), or pass through to
    the underlying runnable without supagraph (if mode no match or mode is
    NONE).

    Note: `self.supagraph_mode` is currently hard-coded to FULL in
    `__init__`, and PIECEWISE keys are never dispatched by `dispatch`.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Set up the capture-size list and the empty per-mode key sets.

        Args:
            vllm_config: Global vLLM configuration. Its `compilation_config`
                supplies `max_capture_size`, `cudagraph_capture_sizes`,
                `level` and the splitting ops.

        Raises:
            AssertionError: If the supagraph mode requires piecewise
                compilation but the compilation config is not set up for it.
        """
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # Fall back to the module-level default list when the configured
        # maximum capture size is 0 or exceeds 256; otherwise use the sizes
        # listed in the compilation config.
        max_capture_size = self.compilation_config.max_capture_size
        self.use_default_list = max_capture_size > 256 or max_capture_size == 0
        if self.use_default_list:
            self.capture_list = _BATCH_SIZES_TO_CAPTURE
        else:
            self.capture_list = self.compilation_config.cudagraph_capture_sizes

        # TODO(liming): Remove this hard code once we support piecewise
        self.supagraph_mode = SUPAGraphMode.FULL

        # Dict to store valid supagraph dispatching keys.
        self.supagraph_keys: dict[SUPAGraphMode, set[BatchDescriptor]] = {
            SUPAGraphMode.PIECEWISE: set(),
            SUPAGraphMode.FULL: set(),
            SUPAGraphMode.FULL_DECODE_ONLY: set(),
        }

        assert not self.supagraph_mode.requires_piecewise_compilation() or \
            (self.compilation_config.level == CompilationLevel.PIECEWISE and
             self.compilation_config.splitting_ops_contain_attention()), \
            "Compilation level should be CompilationLevel.PIECEWISE when "\
            "supagraph_mode piecewise supagraphs is used, "\
            f"supagraph_mode={self.supagraph_mode}, "\
            f"compilation_level={self.compilation_config.level}, "\
            f"splitting_ops={self.compilation_config.splitting_ops}"

        # Flipped to True by initialize_supagraph_keys(); until then
        # dispatch() always answers NONE.
        self.keys_initialized = False

    def add_supagraph_key(self, runtime_mode: SUPAGraphMode,
                          batch_descriptor: BatchDescriptor):
        """Register `batch_descriptor` as a valid key for `runtime_mode`.

        Args:
            runtime_mode: One of PIECEWISE, FULL_DECODE_ONLY or FULL.
            batch_descriptor: The key to add to that mode's key set.

        Raises:
            AssertionError: If `runtime_mode` is not one of the three
                supported modes.
        """
        assert runtime_mode in [SUPAGraphMode.PIECEWISE,
                                SUPAGraphMode.FULL_DECODE_ONLY,
                                SUPAGraphMode.FULL], \
            f"Invalid supagraph runtime mode: {runtime_mode}"
        self.supagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_supagraph_keys(self, supagraph_mode: SUPAGraphMode,
                                  uniform_decode_query_len: int):
        """Populate the dispatch keys for `supagraph_mode`.

        For FULL and FULL_DECODE_ONLY, one uniform-decode key is added for
        every capture size `bs` in `self.capture_list` satisfying
        `uniform_decode_query_len <= bs <= uniform_decode_query_len *
        max_num_seqs`. For any other mode no key is added. In every case the
        dispatcher is marked as initialized afterwards.

        Args:
            supagraph_mode: The mode whose key set should be filled.
            uniform_decode_query_len: Query length of a uniform decode batch;
                used as the lower bound and, multiplied by the scheduler's
                `max_num_seqs`, as the upper bound of the captured sizes.
        """
        # This should be called only after attention backend is initialized.

        # Note: we create all valid keys possible for supagraph but do not
        # guarantee all keys would be used. In addition, if we allow lazy
        # capturing in future PR, some keys may never be triggered.

        # FULL and FULL_DECODE_ONLY register exactly the same uniform-decode
        # keys, each under its own mode.
        if supagraph_mode in (SUPAGraphMode.FULL,
                              SUPAGraphMode.FULL_DECODE_ONLY):
            max_num_tokens = (uniform_decode_query_len *
                              self.vllm_config.scheduler_config.max_num_seqs)
            for bs in self.capture_list:
                if uniform_decode_query_len <= bs <= max_num_tokens:
                    self.add_supagraph_key(
                        supagraph_mode,
                        BatchDescriptor(num_tokens=bs, uniform_decode=True))

        self.keys_initialized = True

    def dispatch(
        self, batch_descriptor: BatchDescriptor
    ) -> tuple[SUPAGraphMode, Optional[BatchDescriptor]]:
        """
        Given a batch descriptor, dispatch to a supagraph mode.

        A new batch descriptor is returned as we might dispatch a uniform
        batch to a graph that supports a more general batch (uniform to
        non-uniform).

        Lookup order: exact match in the FULL_DECODE_ONLY keys, exact match
        in the FULL keys, then `batch_descriptor.non_uniform` in the FULL
        keys.

        Args:
            batch_descriptor: Descriptor of the batch about to be run.

        Returns:
            `(mode, key)` for the first match, or `(SUPAGraphMode.NONE, None)`
            when keys are not initialized or nothing matches.
        """
        # if not initialized, just skip dispatching.
        if not self.keys_initialized:
            logger.warning_once("supagraph dispatching keys are not "
                                "initialized. No supagraph will be used.")
            return SUPAGraphMode.NONE, None

        if batch_descriptor in self.supagraph_keys[
                SUPAGraphMode.FULL_DECODE_ONLY]:
            return SUPAGraphMode.FULL_DECODE_ONLY, batch_descriptor

        # check if key exists for full supagraph
        full_keys = self.supagraph_keys[SUPAGraphMode.FULL]
        if batch_descriptor in full_keys:
            return SUPAGraphMode.FULL, batch_descriptor

        # otherwise, check if non-uniform key exists
        # NOTE(review): `non_uniform` is presumably the same descriptor with
        # the uniform-decode restriction dropped -- confirm in BatchDescriptor.
        non_uniform_key = batch_descriptor.non_uniform
        if non_uniform_key in full_keys:
            return SUPAGraphMode.FULL, non_uniform_key

        # TODO(liming): once piecewise supagraphs are supported, also check
        # if the non-uniform key exists for the more "general" piecewise
        # supagraph:
        # if non_uniform_key in self.supagraph_keys[SUPAGraphMode.PIECEWISE]:
        #     return SUPAGraphMode.PIECEWISE, non_uniform_key

        # finally, just return no supagraphs
        return SUPAGraphMode.NONE, None