Files
enginex-biren-vllm/vllm_br/v1/worker/supagraph_dispatcher.py
2026-03-10 13:31:25 +08:00

156 lines
7.3 KiB
Python

################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import logger
from vllm_br.config.compilation import SUPAGraphMode
from vllm_br.forward_context import BatchDescriptor
# Step between consecutive default capture sizes once past the small ones.
_BATCH_SIZE_ALIGNMENT = 8
# Default capture sizes: 1, 2, 4, then every multiple of the alignment from
# 1x up to and including 33x (8, 16, ..., 264 with the current alignment).
_BATCH_SIZES_TO_CAPTURE = [
    1,
    2,
    4,
    *range(_BATCH_SIZE_ALIGNMENT, _BATCH_SIZE_ALIGNMENT * 34,
           _BATCH_SIZE_ALIGNMENT),
]
class SupagraphDispatcher:
    """
    Runtime supagraph dispatcher to dispatch keys for multiple sets of
    supagraphs.

    The dispatcher stores one set of dispatch keys (batch descriptors) per
    supagraph runtime mode: PIECEWISE, FULL and FULL_DECODE_ONLY. The sets are
    filled by `initialize_supagraph_keys` (through `add_supagraph_key`). The
    keys stored in the dispatcher are the only source of truth for valid
    supagraphs that can be dispatched at runtime.

    At runtime, `dispatch` maps an input batch descriptor to a runtime
    supagraph mode (FULL_DECODE_ONLY, FULL, or NONE for no supagraph) and the
    valid key to use. After dispatching (communicated via forward context),
    the supagraph wrappers are expected to trust the dispatch key to do either
    capturing or replaying (if mode matched), or pass through to the
    underlying runnable without supagraph (if mode does not match or is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        """
        Build an empty dispatcher from the vLLM configuration.

        Args:
            vllm_config: Global vLLM configuration. Its `compilation_config`
                supplies `max_capture_size`, `cudagraph_capture_sizes`,
                `level` and the splitting ops; its `scheduler_config` is read
                later by `initialize_supagraph_keys`.

        Raises:
            AssertionError: If the selected supagraph mode requires piecewise
                compilation but the compilation config is not set up for it.
        """
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # Fall back to the built-in size list when the configured maximum is
        # either unset (0) or larger than 256; otherwise honour the sizes
        # given in the compilation config.
        max_capture_size = self.compilation_config.max_capture_size
        self.use_default_list = (max_capture_size > 256
                                 or max_capture_size == 0)
        self.capture_list = (
            _BATCH_SIZES_TO_CAPTURE if self.use_default_list else
            self.compilation_config.cudagraph_capture_sizes)

        # TODO(liming): Remove this hard code once we support piecewise
        self.supagraph_mode = SUPAGraphMode.FULL

        # Dict to store valid supagraph dispatching keys.
        self.supagraph_keys: dict[SUPAGraphMode, set[BatchDescriptor]] = {
            SUPAGraphMode.PIECEWISE: set(),
            SUPAGraphMode.FULL: set(),
            SUPAGraphMode.FULL_DECODE_ONLY: set(),
        }

        assert not self.supagraph_mode.requires_piecewise_compilation() or \
            (self.compilation_config.level == CompilationLevel.PIECEWISE and
             self.compilation_config.splitting_ops_contain_attention()), \
            "Compilation level should be CompilationLevel.PIECEWISE when "\
            "supagraph_mode piecewise supagraphs is used, "\
            f"supagraph_mode={self.supagraph_mode}, "\
            f"compilation_level={self.compilation_config.level}, "\
            f"splitting_ops={self.compilation_config.splitting_ops}"

        # `dispatch` returns NONE until `initialize_supagraph_keys` has run.
        self.keys_initialized = False

    def add_supagraph_key(self, runtime_mode: SUPAGraphMode,
                          batch_descriptor: BatchDescriptor):
        """
        Register `batch_descriptor` as a valid dispatch key for `runtime_mode`.

        Args:
            runtime_mode: One of PIECEWISE, FULL_DECODE_ONLY or FULL.
            batch_descriptor: Key to add to that mode's key set.

        Raises:
            AssertionError: If `runtime_mode` is not one of the modes above.
        """
        assert runtime_mode in (SUPAGraphMode.PIECEWISE,
                                SUPAGraphMode.FULL_DECODE_ONLY,
                                SUPAGraphMode.FULL), \
            f"Invalid supagraph runtime mode: {runtime_mode}"
        self.supagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_supagraph_keys(self, supagraph_mode: SUPAGraphMode,
                                  uniform_decode_query_len: int):
        """
        Populate the dispatch keys for `supagraph_mode`.

        For FULL and FULL_DECODE_ONLY, one uniform-decode key is added for
        every size in `self.capture_list` that lies within
        `[uniform_decode_query_len,
          uniform_decode_query_len * scheduler_config.max_num_seqs]`.
        Any other mode adds no keys. In all cases the dispatcher is marked as
        initialized afterwards.

        Args:
            supagraph_mode: Mode whose key set should be filled.
            uniform_decode_query_len: Query length of a uniform decode batch;
                used as the lower bound and to derive the upper bound.
        """
        # This should be called only after attention backend is initialized.
        # Note: we create all valid keys possible for supagraph but do not
        # guarantee all keys would be used; some keys may never be triggered.
        if supagraph_mode in (SUPAGraphMode.FULL,
                              SUPAGraphMode.FULL_DECODE_ONLY):
            # Both full modes register the same uniform-decode keys, each
            # under its own mode.
            max_num_tokens = (uniform_decode_query_len *
                              self.vllm_config.scheduler_config.max_num_seqs)
            for bs in self.capture_list:
                if uniform_decode_query_len <= bs <= max_num_tokens:
                    self.add_supagraph_key(
                        supagraph_mode,
                        BatchDescriptor(num_tokens=bs, uniform_decode=True))
        self.keys_initialized = True

    def dispatch(
        self, batch_descriptor: BatchDescriptor
    ) -> tuple[SUPAGraphMode, Optional[BatchDescriptor]]:
        """
        Given a batch descriptor, dispatch to a supagraph mode.

        A new batch descriptor is returned as we might dispatch a uniform batch
        to a graph that supports a more general batch (uniform to non-uniform).

        Args:
            batch_descriptor: Descriptor of the batch about to be run.

        Returns:
            A `(mode, key)` pair. `mode` is FULL_DECODE_ONLY or FULL together
            with the matching key, or `(SUPAGraphMode.NONE, None)` when the
            keys are not initialized or no key matches.
        """
        # if not initialized, just skip dispatching.
        if not self.keys_initialized:
            logger.warning_once("supagraph dispatching keys are not "
                                "initialized. No supagraph will be used.")
            return SUPAGraphMode.NONE, None

        # An exact match is looked up in FULL_DECODE_ONLY first, then FULL.
        for mode in (SUPAGraphMode.FULL_DECODE_ONLY, SUPAGraphMode.FULL):
            if batch_descriptor in self.supagraph_keys[mode]:
                return mode, batch_descriptor

        # otherwise, check if the non-uniform variant of the key exists for
        # full supagraph.
        non_uniform_key = batch_descriptor.non_uniform
        if non_uniform_key in self.supagraph_keys[SUPAGraphMode.FULL]:
            return SUPAGraphMode.FULL, non_uniform_key

        # TODO: PIECEWISE keys are not consulted here yet; add a lookup of
        # `non_uniform_key` in the PIECEWISE set once piecewise supagraphs
        # are supported.

        # finally, just return no supagraphs
        return SUPAGraphMode.NONE, None