156 lines
7.3 KiB
Python
156 lines
7.3 KiB
Python
################################################################################
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
################################################################################
|
|
|
|
from typing import Optional
|
|
|
|
from vllm.config import CompilationLevel, VllmConfig
|
|
from vllm.logger import logger
|
|
from vllm_br.config.compilation import SUPAGraphMode
|
|
from vllm_br.forward_context import BatchDescriptor
|
|
|
|
# Step between consecutive default capture sizes once past the 1/2/4 prefix.
_BATCH_SIZE_ALIGNMENT = 8
# Default supagraph capture sizes: 1, 2, 4, then every multiple of
# _BATCH_SIZE_ALIGNMENT from 8 up to and including 33 * 8 = 264.
# NOTE(review): this list tops out at 264, while SupagraphDispatcher only
# switches to it when max_capture_size > 256 (or == 0) -- confirm the
# trailing 264 entry is intended.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, _BATCH_SIZE_ALIGNMENT * 34,
          _BATCH_SIZE_ALIGNMENT))
|
|
|
|
|
|
class SupagraphDispatcher:
    """
    Runtime supagraph dispatcher to dispatch keys for multiple sets of
    supagraphs.

    The dispatcher stores one set of dispatch keys per supagraph runtime mode
    (PIECEWISE, FULL and FULL_DECODE_ONLY). Keys are registered through
    ``initialize_supagraph_keys`` / ``add_supagraph_key``, and the keys stored
    in the dispatcher are the only source of truth for valid supagraphs that
    can be dispatched at runtime.

    At runtime, ``dispatch`` generates the runtime supagraph mode
    (FULL_DECODE_ONLY, FULL, or NONE for no supagraph) and the valid key
    (batch descriptor) based on the input key. After dispatching
    (communicated via forward context), the supagraph wrappers trust the
    dispatch key to either capture or replay (if the mode matches), or pass
    through to the underlying runnable without a supagraph (if the mode does
    not match or is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        """Create a dispatcher with no registered keys for ``vllm_config``.

        ``initialize_supagraph_keys`` must be called before ``dispatch``;
        until then every batch dispatches to ``SUPAGraphMode.NONE``.
        """
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # Use the module-level default size list when the configured maximum
        # capture size is 0 (or unset) or larger than 256; otherwise use the
        # capture sizes from the compilation config. For ints,
        # ``not max_capture_size`` is the same test as ``== 0``; it also
        # treats ``None`` as unset instead of raising on ``None > 256``.
        max_capture_size = self.compilation_config.max_capture_size
        self.use_default_list = (not max_capture_size
                                 or max_capture_size > 256)
        self.capture_list = (
            _BATCH_SIZES_TO_CAPTURE if self.use_default_list else
            self.compilation_config.cudagraph_capture_sizes)

        # TODO(liming): Remove this hard code once we support piecewise
        self.supagraph_mode = SUPAGraphMode.FULL

        # Valid dispatching keys, one set per supagraph runtime mode.
        self.supagraph_keys: dict[SUPAGraphMode, set[BatchDescriptor]] = {
            SUPAGraphMode.PIECEWISE: set(),
            SUPAGraphMode.FULL: set(),
            SUPAGraphMode.FULL_DECODE_ONLY: set(),
        }

        # If the selected mode requires piecewise compilation, the
        # compilation level must be PIECEWISE and the splitting ops must
        # contain attention.
        assert (not self.supagraph_mode.requires_piecewise_compilation() or
                (self.compilation_config.level == CompilationLevel.PIECEWISE
                 and self.compilation_config.splitting_ops_contain_attention())), (
            "Compilation level should be CompilationLevel.PIECEWISE when "
            "supagraph_mode piecewise supagraphs is used, "
            f"supagraph_mode={self.supagraph_mode}, "
            f"compilation_level={self.compilation_config.level}, "
            f"splitting_ops={self.compilation_config.splitting_ops}")

        self.keys_initialized = False

    def add_supagraph_key(self, runtime_mode: SUPAGraphMode,
                          batch_descriptor: BatchDescriptor):
        """Register ``batch_descriptor`` as dispatchable under ``runtime_mode``.

        ``runtime_mode`` must be PIECEWISE, FULL_DECODE_ONLY or FULL
        (enforced by assertion).
        """
        assert runtime_mode in [
            SUPAGraphMode.PIECEWISE,
            SUPAGraphMode.FULL_DECODE_ONLY,
            SUPAGraphMode.FULL,
        ], f"Invalid supagraph runtime mode: {runtime_mode}"
        self.supagraph_keys[runtime_mode].add(batch_descriptor)

    def _decode_capture_sizes(self,
                              uniform_decode_query_len: int) -> list[int]:
        """Return the capture sizes usable for uniform-decode batches.

        A size ``s`` from ``self.capture_list`` qualifies when
        ``uniform_decode_query_len <= s <= uniform_decode_query_len *
        max_num_seqs``; the original list order is preserved.
        """
        max_num_tokens = (uniform_decode_query_len *
                          self.vllm_config.scheduler_config.max_num_seqs)
        return [
            size for size in self.capture_list
            if uniform_decode_query_len <= size <= max_num_tokens
        ]

    def initialize_supagraph_keys(self, supagraph_mode: SUPAGraphMode,
                                  uniform_decode_query_len: int):
        """Register the dispatch keys for ``supagraph_mode``.

        This should be called only after the attention backend is initialized.

        For FULL and FULL_DECODE_ONLY, one ``uniform_decode=True`` key is
        added per size returned by ``_decode_capture_sizes``. Any other mode
        registers no keys. Keys are added to the existing sets (nothing is
        cleared), and ``keys_initialized`` is set in every case, which
        enables ``dispatch``.

        Note: all potentially valid keys are created, but there is no
        guarantee every key is used at runtime (e.g. if lazy capturing is
        added later, some keys may never be triggered).
        """
        # FULL and FULL_DECODE_ONLY currently both capture only
        # uniform-decode batches, so they share the same key construction.
        if supagraph_mode in (SUPAGraphMode.FULL,
                              SUPAGraphMode.FULL_DECODE_ONLY):
            for num_tokens in self._decode_capture_sizes(
                    uniform_decode_query_len):
                self.add_supagraph_key(
                    supagraph_mode,
                    BatchDescriptor(num_tokens=num_tokens,
                                    uniform_decode=True))
        self.keys_initialized = True

    def dispatch(
        self, batch_descriptor: BatchDescriptor
    ) -> tuple[SUPAGraphMode, Optional[BatchDescriptor]]:
        """
        Given a batch descriptor, dispatch to a supagraph mode.

        A new batch descriptor is returned as we might dispatch a uniform batch
        to a graph that supports a more general batch (uniform to non-uniform).

        Lookup order: FULL_DECODE_ONLY keys, then FULL keys with the exact
        descriptor, then FULL keys with ``batch_descriptor.non_uniform``.
        Returns ``(SUPAGraphMode.NONE, None)`` when nothing matches or when
        the keys have not been initialized yet.
        """
        # If not initialized, skip dispatching entirely.
        if not self.keys_initialized:
            logger.warning_once("supagraph dispatching keys are not "
                                "initialized. No supagraph will be used.")
            return SUPAGraphMode.NONE, None

        decode_only_keys = self.supagraph_keys[SUPAGraphMode.FULL_DECODE_ONLY]
        if batch_descriptor in decode_only_keys:
            return SUPAGraphMode.FULL_DECODE_ONLY, batch_descriptor

        full_keys = self.supagraph_keys[SUPAGraphMode.FULL]
        if batch_descriptor in full_keys:
            return SUPAGraphMode.FULL, batch_descriptor

        # Otherwise fall back to a FULL graph registered for the more
        # general (non-uniform) form of this batch.
        non_uniform_key = batch_descriptor.non_uniform
        if non_uniform_key in full_keys:
            return SUPAGraphMode.FULL, non_uniform_key

        # PIECEWISE keys are intentionally not consulted: piecewise
        # supagraphs are not supported yet (see the TODO in __init__).
        return SUPAGraphMode.NONE, None
|