first commit

This commit is contained in:
2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

View File

@@ -0,0 +1,15 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,49 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import copy
from typing import TYPE_CHECKING
from vllm.config import VllmConfig
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm_br.forward_context import set_forward_context
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
                            vllm_config: VllmConfig) -> ModelRunnerOutput:
    """Drive KV-connector send/recv for a step that runs no model forward.

    Args:
        scheduler_output: Scheduler output for the current step; forwarded to
            ``KVConnectorModelRunnerMixin._get_kv_connector_output``.
        vllm_config: Config used to establish the (attention-metadata-free)
            forward context for the duration of the transfer.

    Returns:
        ``EMPTY_MODEL_RUNNER_OUTPUT`` itself when the connector reports
        neither finished sends nor finished receives; otherwise a shallow
        copy of it with ``kv_connector_output`` attached.
    """
    # KV send/recv even if no work to do.
    with set_forward_context(
            None,
            vllm_config), KVConnectorModelRunnerMixin._get_kv_connector_output(
                scheduler_output, wait_for_save=False) as kv_connector_output:
        pass

    if (not kv_connector_output.finished_sending
            and not kv_connector_output.finished_recving):
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Copy so the shared EMPTY_MODEL_RUNNER_OUTPUT object is never mutated.
    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    output.kv_connector_output = kv_connector_output
    return output


# Install the SUPA implementation on the upstream mixin. It must be wrapped in
# ``staticmethod``: a plain function stored on a class becomes a bound method
# when looked up through an instance, which would pass the runner as
# ``scheduler_output`` and fail with a TypeError on the extra argument.
KVConnectorModelRunnerMixin.kv_connector_no_forward = staticmethod(
    kv_connector_no_forward)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,413 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import threading
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.pynccl_allocator import (
set_graph_pool_id)
from vllm.forward_context import (create_forward_context, get_forward_context,
override_forward_context)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import has_deep_gemm
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
from vllm_br.compilation.supa_graph import SUPAGraphWrapper
from vllm_br.config.compilation import SUPAGraphMode
logger = init_logger(__name__)
@dataclass
class UbatchMetadata:
    """Everything one micro-batch worker thread needs to run the model."""

    # Synchronization context entered by the thread running this ubatch.
    context: UBatchContext
    # Model inputs already sliced down to this ubatch's token range.
    input_ids: torch.Tensor
    positions: torch.Tensor
    inputs_embeds: Optional[torch.Tensor]
    intermediate_tensors: Optional[IntermediateTensors]
    # Length of this ubatch's token slice (stop - start).
    num_tokens: int
@dataclass
class SUPAGraphMetaData:
    """A captured micro-batched SUPA graph and the data needed to replay it."""

    # The captured graph object; replayed via ``supagraph.replay()``.
    supagraph: torch.supa.SUPAGraph
    # The per-ubatch metadata that was used while capturing the graph.
    ubatch_metadata: UbatchMetadata
    # Model output produced during capture; returned again after each replay.
    outputs: Optional[Any] = None
class SMControlContextManager:
    """Temporarily split the device's SMs between communication and compute.

    On entry the two setter callbacks are invoked with ``comm_sms`` and
    ``total_sms - comm_sms`` respectively. On exit both setters are invoked
    with ``total_sms`` so that every SM is available again.

    Args:
        comm_sms (int): Number of SMs to give to communication; the rest go
            to computation. Must be smaller than the device's SM count.
        set_comm_sms (Callable[[int], None]): Applies the communication SM
            count.
        set_compute_sms (Callable[[int], None]): Applies the computation SM
            count.
    """

    def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None],
                 set_compute_sms: Callable[[int], None]):
        assert current_platform.is_supa(), \
            "SM control is currently only supported on SUPA"

        # Total SM count is read from the currently selected SUPA device.
        device_index = torch.supa.current_device()
        sm_count = torch.supa.get_device_properties(
            device_index).multi_processor_count
        assert comm_sms < sm_count

        self.total_sms = sm_count
        self.compute_sms = sm_count - comm_sms
        self.comm_sms = comm_sms
        self.set_comm_sms = set_comm_sms
        self.set_compute_sms = set_compute_sms

    def _apply(self, comm: int, compute: int) -> None:
        # Communication is always configured before computation.
        self.set_comm_sms(comm)
        self.set_compute_sms(compute)

    def __enter__(self):
        self._apply(self.comm_sms, self.compute_sms)

    def __exit__(self, exc_type, exc_value, traceback):
        self._apply(self.total_sms, self.total_sms)
class UBatchWrapper:
    """Run a model as two micro-batches (ubatches) on separate threads.

    When the active forward context carries ``ubatch_slices`` the inputs are
    split per slice and each slice is executed on its own thread, either
    eagerly, while capturing a SUPA graph, or by replaying a previously
    captured graph. Without ``ubatch_slices`` the call is passed through to
    the wrapped runnable (optionally via ``SUPAGraphWrapper``).

    Unknown attributes are forwarded to the wrapped runnable, which is how
    ``self.model`` used in ``__call__`` is resolved.
    """

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_mode: SUPAGraphMode, device: torch.supa.device):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.comm_stream = torch.supa.Stream(device=device)
        # Two ubatch threads plus the main thread
        self.ready_barrier = threading.Barrier(3)

        # Captured ubatched graphs, keyed by total token count.
        self.supagraphs: dict[int, SUPAGraphMetaData] = {}

        self.supagraph_wrapper = None
        self.graph_pool = None
        if runtime_mode is not SUPAGraphMode.NONE:
            self.supagraph_wrapper = SUPAGraphWrapper(
                runnable, vllm_config, runtime_mode=runtime_mode)
            self.graph_pool = current_platform.get_global_graph_pool()

        self.sm_control = self._create_sm_control_context(vllm_config)
        self.device = device

    @staticmethod
    def _create_sm_control_context(vllm_config: VllmConfig):
        """Build the SM-partitioning context used around ubatched runs.

        The communication SM budget starts at ``envs.VLLM_DBO_COMM_SMS`` and,
        with expert parallelism enabled, is capped by the all2all manager's
        ``max_sms_used()``. Setters stay no-ops unless the budget is positive
        (and, for compute, DeepGEMM is available).
        """
        comm_sms = envs.VLLM_DBO_COMM_SMS

        set_comm_sms = lambda sms: None
        if vllm_config.parallel_config.enable_expert_parallel:
            # Currently only DeepEP highthroughput supports SM control so this
            # only affects that case.
            all2all_manager = get_ep_group(
            ).device_communicator.all2all_manager

            if all2all_manager.max_sms_used() is not None:
                comm_sms = min(comm_sms, all2all_manager.max_sms_used())

            if comm_sms > 0:
                set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)

        # TODO(lucas): support other kernels besides DeepGEMM
        set_compute_sms = lambda sms: None
        if has_deep_gemm() and comm_sms > 0:
            import deep_gemm as dg
            set_compute_sms = lambda sms: dg.set_num_sms(sms)

        return SMControlContextManager(comm_sms=comm_sms,
                                       set_comm_sms=set_comm_sms,
                                       set_compute_sms=set_compute_sms)

    def __getattr__(self, key: str):
        """Forward unknown attribute lookups to the wrapped runnable.

        Raises:
            AttributeError: If the runnable does not provide ``key`` (or no
                runnable has been set on this instance yet).
        """
        # Read ``runnable`` straight from the instance dict: going through
        # ``self.runnable`` would re-enter __getattr__ and recurse without
        # bound whenever the attribute is not set yet (e.g. while the object
        # is being copied or unpickled).
        runnable = self.__dict__.get("runnable")
        # allow accessing the attributes of the runnable.
        if runnable is not None and hasattr(runnable, key):
            return getattr(runnable, key)
        raise AttributeError(f"Attribute {key} not exists in the runnable of "
                             f"supagraph wrapper: {runnable}")

    def unwrap(self) -> Callable:
        """Return the original wrapped runnable."""
        # in case we need to access the original runnable.
        return self.runnable

    def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """
        Capture a supagraph for a microbatched run.

        The logic here is somewhat complicated because we need to make sure that
        each of the ubatch threads initialize the supa context before we start
        the graph capture.

        The flow is as follows:
        1. The main thread starts up each ubatch thread. Each thread will
        initialize its supa context (torch.supa.current_blas_handle())
        before going to sleep upon entering the ubatch_context.

        2. The main thread starts the graph capture and wakes up the first
        ubatch thread.

        3. Each ubatch thread runs the model to completion and returns the
        completed output tensors back to the main thread.

        4. The main thread stores the captured supagraph along with its metadata
        and returns
        """

        @torch.inference_mode()
        def _capture_ubatch_thread(results, ubatch_metadata):
            torch.supa.set_device(self.device)
            ubatch_context = ubatch_metadata.context
            # Touch the BLAS handle on both streams before capture begins.
            with torch.supa.stream(ubatch_context.compute_stream):
                _ = torch.supa.current_blas_handle()
            with torch.supa.stream(ubatch_context.comm_stream):
                _ = torch.supa.current_blas_handle()
            with ubatch_context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )

            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []
        compute_stream = ubatch_metadata[0].context.compute_stream
        num_tokens = ubatch_metadata[0].num_tokens + \
            ubatch_metadata[1].num_tokens

        # Ubatches will manually manage the forward context, so we override
        # it to None here so we can have it restored correctly later
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(target=_capture_ubatch_thread,
                                          args=(
                                              results,
                                              metadata,
                                          ))
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready

            # Capture the supagraph
            supagraph_metadata = \
                SUPAGraphMetaData(
                    supagraph=torch.supa.SUPAGraph(),
                    ubatch_metadata=ubatch_metadata,
                )
            if self.graph_pool is not None:
                set_graph_pool_id(self.graph_pool)
            else:
                set_graph_pool_id(current_platform.graph_pool_handle())
            with torch.supa.graph(supagraph_metadata.supagraph,
                                  stream=compute_stream,
                                  pool=self.graph_pool):
                # Wake the first ubatch thread; the threads then hand control
                # back and forth until both have finished.
                ubatch_metadata[0].context.cpu_wait_event.set()
                for thread in ubatch_threads:
                    thread.join()
                # Order outputs by ubatch id before concatenating.
                sorted_results = [value for position, value in sorted(results)]
                result = torch.cat(sorted_results, dim=0)
                supagraph_metadata.outputs = result
            self.supagraphs[num_tokens] = supagraph_metadata
        return supagraph_metadata.outputs

    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """Run every ubatch eagerly on its own thread and concatenate outputs.

        Outputs are ordered by ubatch context id and concatenated along
        dimension 0.
        """

        @torch.inference_mode()
        def _ubatch_thread(results, model, ubatch_metadata):
            with ubatch_metadata.context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )
            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []

        # Ubatch threads will manually manage the forward context, so we
        # override it to None here so we can have it restored correctly
        # after both threads have finished
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(target=_ubatch_thread,
                                          args=(
                                              results,
                                              model,
                                              metadata,
                                          ))
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready
            ubatch_metadata[0].context.cpu_wait_event.set()
            for thread in ubatch_threads:
                thread.join()
        sorted_results = [value for position, value in sorted(results)]
        result = torch.cat(sorted_results, dim=0)
        return result

    def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
                              positions, inputs_embeds, intermediate_tensors,
                              compute_stream, dp_metadata, batch_descriptor,
                              supagraph_runtime_mode) -> list[UbatchMetadata]:
        """Build one ``UbatchMetadata`` (context + sliced inputs) per slice."""
        # Create one forward context per ubatch
        forward_contexts = []
        for i, ubatch_slice in enumerate(ubatch_slices):
            forward_contexts.append(
                create_forward_context(
                    attn_metadata[i] if attn_metadata is not None else None,
                    self.vllm_config,
                    dp_metadata=dp_metadata,
                    batch_descriptor=batch_descriptor,
                    supagraph_runtime_mode=supagraph_runtime_mode))

        ubatch_ctxs = make_ubatch_contexts(
            num_micro_batches=len(ubatch_slices),
            comm_stream=self.comm_stream,
            compute_stream=compute_stream,
            forward_contexts=forward_contexts,
            ready_barrier=self.ready_barrier)

        ubatch_metadata: list[UbatchMetadata] = []
        for i, ubatch_slice in enumerate(ubatch_slices):
            sliced_input_ids, sliced_positions, sliced_inputs_embeds, \
                sliced_intermediate_tensors = \
                self._slice_model_inputs(
                    ubatch_slice.token_slice, input_ids, positions,
                    inputs_embeds, intermediate_tensors)
            ubatch_metadata.append(
                UbatchMetadata(
                    context=ubatch_ctxs[i],
                    input_ids=sliced_input_ids,
                    positions=sliced_positions,
                    inputs_embeds=sliced_inputs_embeds,
                    intermediate_tensors=sliced_intermediate_tensors,
                    num_tokens=ubatch_slice.token_slice.stop -
                    ubatch_slice.token_slice.start))

        return ubatch_metadata

    def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions,
                            inputs_embeds, intermediate_tensors):
        """Slice the model inputs down to the token range of one ubatch.

        Returns:
            Tuple ``(input_ids, positions, inputs_embeds,
            intermediate_tensors)`` restricted to ``tokens_slice``; the last
            two entries are ``None`` when the corresponding input is absent.
        """
        sliced_input_ids = input_ids[tokens_slice]
        # if we are using mrope. Mrope adds an additional dimension to the
        # positions tensor
        if positions.ndim == 2:
            sliced_positions = positions[:, tokens_slice]
        else:
            sliced_positions = positions[tokens_slice]
        # Compare against None explicitly: evaluating a tensor with more than
        # one element in a boolean context raises RuntimeError.
        sliced_inputs_embeds = inputs_embeds[
            tokens_slice] if inputs_embeds is not None else None
        sliced_intermediate_tensors = intermediate_tensors[
            tokens_slice] if intermediate_tensors else None
        return (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
                sliced_intermediate_tensors)

    def __call__(self, *args, **kwargs):
        """Dispatch a forward call according to the active forward context.

        Without ubatch slices the call goes to the raw runnable or the
        ``SUPAGraphWrapper``. With ubatch slices the model inputs must be
        supplied as the keyword arguments ``input_ids``, ``positions``,
        ``intermediate_tensors`` and ``inputs_embeds``; the run is then
        captured, replayed, or executed eagerly.
        """
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        ubatch_slices = forward_context.ubatch_slices
        supagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        # If there's no ubatching, just run the runnable object
        if ubatch_slices is None:

            # This is to account for the case where ubatching was aborted.
            # When we capture full graphs we only capture one graph per shape,
            # meaning that if we have a ubatched supagraph for the current
            # num_tokens, we don't have a non-ubatched one. Without this
            # check, the supagraph wrapper will try to capture a supagraph
            # for this shape during a normal run.
            if supagraph_runtime_mode is SUPAGraphMode.FULL:
                assert batch_descriptor is not None
                if batch_descriptor.num_tokens in self.supagraphs:
                    supagraph_runtime_mode = SUPAGraphMode.NONE

            if supagraph_runtime_mode in (SUPAGraphMode.NONE,
                                          SUPAGraphMode.PIECEWISE):
                return self.runnable(*args, **kwargs)
            else:
                assert self.supagraph_wrapper is not None
                return self.supagraph_wrapper(*args, **kwargs)

        attn_metadata = forward_context.attn_metadata
        # NOTE: the key is twice the first slice's length, i.e. this assumes
        # both ubatches cover the same number of tokens.
        num_tokens = (ubatch_slices[0].token_slice.stop -
                      ubatch_slices[0].token_slice.start) * 2
        input_ids = kwargs['input_ids']
        positions = kwargs['positions']
        intermediate_tensors = kwargs['intermediate_tensors']
        inputs_embeds = kwargs['inputs_embeds']
        compute_stream = torch.supa.current_stream()

        dp_metadata = forward_context.dp_metadata

        # We shouldn't be here unless we are running with multiple DP ranks
        assert dp_metadata is not None

        use_full_graph = supagraph_runtime_mode is SUPAGraphMode.FULL

        # A graph for this size already exists: replay it.
        if use_full_graph and num_tokens in self.supagraphs:
            supagraph_metadata = self.supagraphs[num_tokens]
            supagraph_metadata.supagraph.replay()
            return supagraph_metadata.outputs

        # Both remaining paths (capture and eager) need per-ubatch metadata;
        # the per-ubatch forward contexts always use runtime mode NONE.
        ubatch_metadata = self._make_ubatch_metadata(
            ubatch_slices=ubatch_slices,
            attn_metadata=attn_metadata,
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            compute_stream=compute_stream,
            dp_metadata=dp_metadata,
            batch_descriptor=batch_descriptor,
            supagraph_runtime_mode=SUPAGraphMode.NONE)
        with self.sm_control:
            if use_full_graph:
                return self._capture_ubatches(ubatch_metadata, self.model)
            return self._run_ubatches(ubatch_metadata, self.model)

View File

@@ -0,0 +1,155 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import logger
from vllm_br.config.compilation import SUPAGraphMode
from vllm_br.forward_context import BatchDescriptor
# Step between consecutive default capture sizes once past the small sizes.
_BATCH_SIZE_ALIGNMENT = 8
# Default capture sizes: 1, 2, 4, then every multiple of the alignment from
# 8 up to and including 33 * 8 = 264.
_BATCH_SIZES_TO_CAPTURE = [
    1, 2, 4,
    *range(_BATCH_SIZE_ALIGNMENT, _BATCH_SIZE_ALIGNMENT * 34,
           _BATCH_SIZE_ALIGNMENT),
]
class SupagraphDispatcher:
    """
    Runtime dispatcher that maps a batch descriptor to a supagraph mode.

    The dispatcher keeps one set of valid dispatch keys (batch descriptors)
    per supagraph runtime mode. Those sets are the only source of truth for
    which supagraphs may be dispatched at runtime; they are filled by
    `initialize_supagraph_keys`.

    At runtime, `dispatch` returns the runtime mode (FULL, FULL_DECODE_ONLY,
    or NONE for no supagraph) together with the key that matched. The result
    is communicated through the forward context, and the supagraph wrappers
    rely on it to decide between capturing/replaying (mode matches) and
    passing through to the underlying runnable (mode does not match or is
    NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # Fall back to the built-in size list when the configured maximum is
        # larger than 256 or is 0.
        max_capture_size = self.compilation_config.max_capture_size
        self.use_default_list = max_capture_size > 256 or max_capture_size == 0
        if self.use_default_list:
            self.capture_list = _BATCH_SIZES_TO_CAPTURE
        else:
            self.capture_list = self.compilation_config.cudagraph_capture_sizes

        # TODO(liming): Remove this hard code once we support piecewise
        self.supagraph_mode = SUPAGraphMode.FULL

        # Dict to store valid supagraph dispatching keys.
        self.supagraph_keys: dict[SUPAGraphMode, set[BatchDescriptor]] = {
            SUPAGraphMode.PIECEWISE: set(),
            SUPAGraphMode.FULL: set(),
            SUPAGraphMode.FULL_DECODE_ONLY: set(),
        }

        assert not self.supagraph_mode.requires_piecewise_compilation() or \
            (self.compilation_config.level == CompilationLevel.PIECEWISE and
             self.compilation_config.splitting_ops_contain_attention()), \
            "Compilation level should be CompilationLevel.PIECEWISE when "\
            "supagraph_mode piecewise supagraphs is used, "\
            f"supagraph_mode={self.supagraph_mode}, "\
            f"compilation_level={self.compilation_config.level}, "\
            f"splitting_ops={self.compilation_config.splitting_ops}"

        self.keys_initialized = False

    def add_supagraph_key(self, runtime_mode: SUPAGraphMode,
                          batch_descriptor: BatchDescriptor):
        """Register `batch_descriptor` as a valid key for `runtime_mode`."""
        allowed_modes = (SUPAGraphMode.PIECEWISE,
                         SUPAGraphMode.FULL_DECODE_ONLY, SUPAGraphMode.FULL)
        assert runtime_mode in allowed_modes, \
            f"Invalid supagraph runtime mode: {runtime_mode}"
        self.supagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_supagraph_keys(self, supagraph_mode: SUPAGraphMode,
                                  uniform_decode_query_len: int):
        """Populate the dispatch keys for `supagraph_mode`.

        For FULL and FULL_DECODE_ONLY, one uniform-decode key is added for
        every capture size between `uniform_decode_query_len` and
        `uniform_decode_query_len * max_num_seqs` (both inclusive). Other
        modes add no keys. Afterwards the dispatcher is marked initialized.
        """
        # This should be called only after attention backend is initialized.
        # Note: all possible keys are created here, but there is no guarantee
        # that every key will actually be captured or replayed.
        if supagraph_mode in (SUPAGraphMode.FULL,
                              SUPAGraphMode.FULL_DECODE_ONLY):
            token_limit = (uniform_decode_query_len *
                           self.vllm_config.scheduler_config.max_num_seqs)
            decode_sizes = [
                size for size in self.capture_list
                if size <= token_limit and size >= uniform_decode_query_len
            ]
            for size in decode_sizes:
                self.add_supagraph_key(
                    supagraph_mode,
                    BatchDescriptor(num_tokens=size, uniform_decode=True))

        self.keys_initialized = True

    def dispatch(
        self, batch_descriptor: BatchDescriptor
    ) -> tuple[SUPAGraphMode, Optional[BatchDescriptor]]:
        """
        Given a batch descriptor, dispatch to a supagraph mode.

        A new batch descriptor is returned as we might dispatch a uniform batch
        to a graph that supports a more general batch (uniform to non-uniform).
        """
        # if not initialized, just skip dispatching.
        if not self.keys_initialized:
            logger.warning_once("supagraph dispatching keys are not "
                                "initialized. No supagraph will be used.")
            return SUPAGraphMode.NONE, None

        # Exact match first: decode-only keys take priority over full keys.
        for mode in (SUPAGraphMode.FULL_DECODE_ONLY, SUPAGraphMode.FULL):
            if batch_descriptor in self.supagraph_keys[mode]:
                return mode, batch_descriptor

        # Otherwise fall back to the non-uniform variant of the key.
        relaxed_key = batch_descriptor.non_uniform
        if relaxed_key in self.supagraph_keys[SUPAGraphMode.FULL]:
            return SUPAGraphMode.FULL, relaxed_key

        # NOTE: PIECEWISE keys are never consulted here; piecewise dispatch
        # is currently disabled.

        # finally, just return no supagraphs
        return SUPAGraphMode.NONE, None

View File

@@ -0,0 +1,195 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import threading
import torch
from vllm import forward_context
from vllm.forward_context import ForwardContext
from vllm.utils import current_stream
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
class SUPAUBatchContext:
    """
    Context manager for micro-batching synchronization using threading events.

    Each micro-batch thread owns one context. Only one thread is meant to run
    at a time: a thread hands over by setting the other thread's CPU event
    (`cpu_signal_event`) and then blocking on its own (`cpu_wait_event`).
    Device-side ordering between the compute and communication streams is
    expressed with the two SUPA events.
    """

    def __init__(self,
                 id: int,
                 comm_stream: torch.supa.Stream,
                 compute_stream: torch.supa.Stream,
                 forward_context: ForwardContext,
                 ready_barrier: threading.Barrier,
                 cpu_wait_event: threading.Event,
                 cpu_signal_event: threading.Event,
                 gpu_comm_done_event: torch.supa.Event,
                 gpu_compute_done_event: torch.supa.Event,
                 schedule: str = "default"):
        self.id = id
        self.comm_stream = comm_stream
        self.compute_stream = compute_stream
        self.forward_context = forward_context
        self.ready_barrier = ready_barrier
        self.cpu_wait_event = cpu_wait_event
        self.cpu_signal_event = cpu_signal_event
        self.current_stream = compute_stream
        self.gpu_comm_done_event = gpu_comm_done_event
        self.gpu_compute_done_event = gpu_compute_done_event
        self.schedule = schedule
        # Optional callback, run once by maybe_run_recv_hook().
        self.recv_hook = None

    def __enter__(self):
        # This module defines no `_THREAD_ID_TO_CONTEXT` / `_CURRENT_CONTEXTS`
        # of its own, so referring to them as bare globals raises NameError.
        # Register with the registries of the upstream ubatching module that
        # this file already imports from.
        # NOTE(review): assumes upstream keeps exposing both names - confirm
        # against the pinned vLLM version.
        from vllm.v1.worker import ubatching as _upstream_ubatching
        _upstream_ubatching._THREAD_ID_TO_CONTEXT[
            threading.get_ident()] = self.id
        _upstream_ubatching._CURRENT_CONTEXTS[self.id] = self
        # Rendezvous with the other parties, then sleep until woken.
        self.ready_barrier.wait()

        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()
        # Assume we want to start on the compute stream
        self.update_stream(self.compute_stream)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # See __enter__ for why the upstream registries are used.
        from vllm.v1.worker import ubatching as _upstream_ubatching
        _upstream_ubatching._CURRENT_CONTEXTS[self.id] = None
        del _upstream_ubatching._THREAD_ID_TO_CONTEXT[threading.get_ident()]
        self.maybe_run_recv_hook()
        # Wake the other thread so it can run to completion.
        self.cpu_signal_event.set()
        self.cpu_wait_event.clear()
        # Never suppress exceptions raised inside the with-block.
        return False

    def _restore_context(self):
        """Install this ubatch's forward context as the active one."""
        forward_context._forward_context = self.forward_context

    def update_stream(self, stream):
        """Record `stream` as current and switch the device to it if needed."""
        self.current_stream = stream
        if current_stream() != self.current_stream:
            torch.supa.set_stream(self.current_stream)

    def _signal_comm_done(self):
        self.gpu_comm_done_event.record(self.comm_stream)

    def _signal_compute_done(self):
        self.gpu_compute_done_event.record(self.compute_stream)

    def _wait_compute_done(self):
        self.comm_stream.wait_event(self.gpu_compute_done_event)

    def _wait_comm_done(self):
        self.compute_stream.wait_event(self.gpu_comm_done_event)

    def _cpu_yield(self):
        """Hand control to the other thread and block until woken again."""
        # It is critical for correctness that only one thread is running
        # at a time. These asserts just make sure that this is the only
        # thread running before waking the other one up and going to sleep
        assert forward_context._forward_context == self.forward_context
        assert current_stream() == self.current_stream
        assert not self.cpu_wait_event.is_set()

        self.cpu_signal_event.set()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm(self):
        """Switch to the communication stream without device-side waiting."""
        self.update_stream(self.comm_stream)

    def switch_to_compute(self):
        """Switch to the compute stream without device-side waiting."""
        self.update_stream(self.compute_stream)

    def switch_to_comm_sync(self):
        """Switch to the comm stream, ordered after prior compute work."""
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def switch_to_compute_sync(self):
        """Switch to the compute stream, ordered after prior comm work."""
        self._signal_comm_done()
        self.update_stream(self.compute_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        """Run the pending receive hook, if any, exactly once."""
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

    def yield_(self):
        """Yield to the other thread and resume on the same stream."""
        self.current_stream = current_stream()
        self._cpu_yield()
        self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        """Yield while on compute, then continue on the comm stream."""
        assert current_stream() == self.compute_stream
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        """Yield while on comm, then continue on the compute stream."""
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()
def supa_make_ubatch_contexts(
    num_micro_batches: int,
    compute_stream: torch.supa.Stream,
    comm_stream: torch.supa.Stream,
    forward_contexts: list[ForwardContext],
    ready_barrier: threading.Barrier,
    schedule: str = "default",
) -> list[UBatchContext]:
    """
    Create the context managers used for micro-batching synchronization.

    Args:
        num_micro_batches: Number of micro-batches; must be 2.
        compute_stream: Stream shared by all contexts for computation.
        comm_stream: Stream shared by all contexts for communication.
        forward_contexts: One forward context per micro-batch (length 2).
        ready_barrier: Barrier each context waits on when it is entered.
        schedule: Schedule name stored on every context.

    Returns:
        One context per micro-batch, in id order. The CPU events are wired in
        a ring: context ``i`` waits on event ``i`` and signals event
        ``(i + 1) % num_micro_batches``.
    """
    assert num_micro_batches == 2, "only been tested with 2 micro-batches"
    cpu_events = [threading.Event() for _ in range(num_micro_batches)]
    gpu_comm_done_events = [
        torch.supa.Event() for _ in range(num_micro_batches)
    ]
    gpu_compute_done_events = [
        torch.supa.Event() for _ in range(num_micro_batches)
    ]

    assert len(forward_contexts) == 2

    ctxs = []
    for i in range(num_micro_batches):
        # `UBatchContext` is looked up in module globals at call time.
        ctx = UBatchContext(id=i,
                            compute_stream=compute_stream,
                            comm_stream=comm_stream,
                            forward_context=forward_contexts[i],
                            ready_barrier=ready_barrier,
                            cpu_wait_event=cpu_events[i],
                            cpu_signal_event=cpu_events[(i + 1) %
                                                        num_micro_batches],
                            gpu_comm_done_event=gpu_comm_done_events[i],
                            gpu_compute_done_event=gpu_compute_done_events[i],
                            schedule=schedule)
        ctxs.append(ctx)

    return ctxs
# Rebind the names imported from vllm.v1.worker.ubatching to the SUPA
# implementations defined in this module. From here on, the `UBatchContext`
# global that supa_make_ubatch_contexts() looks up at call time is
# SUPAUBatchContext.
# NOTE(review): plain assignment only changes this module's namespace; code
# that imports these names directly from vllm.v1.worker.ubatching still gets
# the upstream objects - confirm that this is the intended scope.
UBatchContext = SUPAUBatchContext
make_ubatch_contexts = supa_make_ubatch_contexts

View File

@@ -0,0 +1,86 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from typing import TYPE_CHECKING, Optional
import torch
from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.attention.layer import Attention
def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: Optional[int] = 1,
) -> None:
    """
    Attach the allocated KV caches to the ModelRunner and the attention layers.

    Two things happen:
      1) `runner_kv_caches` (which must be empty on entry) is filled with one
         cache per layer index, in ascending layer-index order.
      2) Every attention layer in `forward_context` gets its own cache from
         `kv_caches` assigned to `kv_cache`, wrapped in a one-element list.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache list declared by ModelRunner; it is
            appended to in place.
        num_attn_module: Forwarded to `extract_layer_index` when deriving a
            layer index from a layer name.

    Raises:
        NotImplementedError: If several layer names share one layer index on
            a platform other than CUDA, XPU or SUPA.
    """
    # The runner-side list is populated here and must start out empty.
    assert len(runner_kv_caches) == 0

    # Group layer names by the layer index parsed from each name.
    names_by_index = defaultdict(list)
    for name in kv_caches:
        names_by_index[extract_layer_index(name,
                                           num_attn_module)].append(name)

    for index in sorted(names_by_index):
        names = names_by_index[index]
        # Several names can share an index, e.g. the cross attention and self
        # attention of one decoder layer in an encoder-decoder model such as
        # bart. This is only tolerated on CUDA, XPU and SUPA, where the first
        # name is used and the others are ignored for the runner list.
        # TODO - analyze where runner_kv_caches is used and the right
        # way to ensure it properly reflects multiple attention layers
        # in the same decoder block.
        if len(names) > 1 and not (current_platform.is_cuda()
                                   or current_platform.is_xpu()
                                   or current_platform.is_supa()):
            raise NotImplementedError
        runner_kv_caches.append(kv_caches[names[0]])

    # Bind kv_caches to forward context
    for name, cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[name].kv_cache = [cache]

429
vllm_br/v1/worker/worker.py Normal file
View File

@@ -0,0 +1,429 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
"""A GPU worker class."""
import copy
import datetime
import gc
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch.nn as nn
import vllm.envs as envs
import vllm_br.envs as br_envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.worker_base import WorkerBase
from vllm_br.platform import SUPAPlatform
from vllm_br.utils import GiB_bytes, SUPAMemorySnapshot
from vllm_br.v1.worker.model_runner import SUPAModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
class SUPAWorker(WorkerBase):
    """Worker that owns one SUPA device and drives a `SUPAModelRunner` on it.

    The worker selects the device, initializes the distributed environment,
    profiles memory to size the KV cache, warms up the model and forwards
    `execute_model` calls to the model runner. Sleep mode is not supported.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Store the configuration and optionally set up the torch profiler.

        Args:
            vllm_config: The vLLM configuration for this worker.
            local_rank: Rank used to pick the local `supa:<index>` device.
            rank: Rank of this worker in the distributed group.
            distributed_init_method: Init method passed to
                `init_distributed_environment`.
            is_driver_worker: Forwarded unchanged to `WorkerBase`.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )
        self.kv_transfer_config = vllm_config.kv_transfer_config

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Buffers saved before sleep
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True),
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.SUPA,  # type: ignore
                ],
                schedule=torch.profiler.schedule(wait=0,
                                                 warmup=0,
                                                 active=1,
                                                 repeat=1),
                profile_memory=False,
                record_shapes=True,
                with_stack=False,
                use_supa_simple=True,  # type: ignore
            )
        else:
            self.profiler = None

    def sleep(self, level: int = 1) -> None:
        """Not implemented for SUPA; always raises `NotImplementedError`."""
        raise NotImplementedError

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Not implemented for SUPA; always raises `NotImplementedError`."""
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the device and CPU block counts on the cache config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Select the SUPA device, set up distributed state and the runner.

        Raises:
            RuntimeError: If the configured device type is not "supa".
        """
        if self.device_config.device.type != "supa":
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        self.device = torch.device(f"supa:{self.local_rank}")
        if self.kv_transfer_config is not None:
            # With KV transfer configured, the device index is shifted by the
            # "device_cursor" entry of the extra config (0 when absent).
            device_cursor = self.kv_transfer_config.get_from_extra_config(
                "device_cursor", 0)
            self.device = torch.device(
                f"supa:{self.local_rank + int(device_cursor)}")
        SUPAPlatform.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking
        # memory snapshot
        # This ensures SUCCL buffers are allocated before we measure
        # available memory
        self._init_worker_distributed_environment()

        # Set random seed.
        set_random_seed(self.model_config.seed)

        gc.collect()
        torch.supa.empty_cache()
        # NOTE(review): this stores the value of get_device_total_memory(),
        # while determine_available_memory() reports it as "Initial free
        # memory" in its assertion message - confirm which is intended.
        self.init_gpu_memory = SUPAPlatform.get_device_total_memory()

        # Construct the model runner
        self.model_runner: SUPAModelRunner = SUPAModelRunner(  # type: ignore
            self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model through the model runner.

        Raises:
            NotImplementedError: If sleep mode is enabled in the model config.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            raise NotImplementedError('SUPA do not support sleep mode')
        self.model_runner.load_model()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profile a dummy forward pass and size the KV cache budget.

        The budget is `total device memory * gpu_memory_utilization` minus
        the torch-allocated memory measured after the profiling run.

        Returns:
            The number of bytes available for the KV cache, truncated to int.
        """
        torch.supa.empty_cache()
        _, total_gpu_memory = torch.supa.mem_get_info()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        before_profile = SUPAMemorySnapshot()
        after_profile = SUPAMemorySnapshot()
        before_profile.measure()
        self.model_runner.profile_run()
        after_profile.measure()

        free_gpu_memory, _ = torch.supa.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_gpu_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        # NOTE(review): despite the name, this is read from
        # memory_allocated() after the profile run rather than from a
        # max-allocated counter - confirm this is intended.
        peak_memory = torch.supa.memory_allocated()

        # Check for any memory left around that may have been allocated on the
        # gpu outside of `torch`. NCCL operations, for example, can use a few
        # GB during a forward pass
        torch.supa.empty_cache()
        torch_allocated_bytes = SUPAPlatform.get_memory_stats(
            self.device, "allocated_bytes.all.current")
        # Query once so that free and total come from the same snapshot.
        free_after_profile, total_after_profile = torch.supa.mem_get_info()
        total_allocated_bytes = total_after_profile - free_after_profile
        # NOTE: non-torch allocations are only reported in the log message
        # below; they are deliberately not added to `peak_memory`.
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes

        memory_for_current_instance = total_gpu_memory * \
            self.cache_config.gpu_memory_utilization
        available_kv_cache_memory = (
            total_gpu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)

        diff_profile = after_profile - before_profile

        msg = (f"Memory profiling takes {diff_profile.timestamp:.2f} seconds\n"
               "the current vLLM instance can use "
               "total_gpu_memory "
               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
               " x gpu_memory_utilization "
               f"({self.cache_config.gpu_memory_utilization:.2f})"
               f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
               "model weights take "
               f"{(self.model_runner.model_memory_usage / GiB_bytes):.2f}GiB;"
               " non_torch_memory takes "
               f"{(non_torch_allocations / GiB_bytes):.2f}GiB;"
               " PyTorch activation peak memory takes "
               f"{(diff_profile.torch_peak / GiB_bytes):.2f}GiB;"
               " the rest of the memory reserved for KV Cache is "
               f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
        logger.info(msg)

        return int(available_kv_cache_memory)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        Raises:
            NotImplementedError: If sleep mode is enabled in the model config.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            raise NotImplementedError('SUPA do not support sleep mode')
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Warm up compile sizes, capture the model and warm up sampling."""
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes
                if x not in self.scheduler_config.cuda_graph_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size,
                                         skip_eplb=True,
                                         remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `SUPAPlatform.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )
            hidden_states, last_hidden_states = \
                self.model_runner._dummy_run(
                    num_tokens=max_num_reqs,
                    skip_eplb=True,
                )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(
                    hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model runner."""
        return self.model_runner.get_supported_tasks()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
        """Run one step and handle pipeline-parallel tensor exchange.

        On a non-first pipeline rank with tokens scheduled, intermediate
        tensors are first received from the pipeline group. If the runner
        returns `IntermediateTensors`, they are sent onward through the
        pipeline group instead of being returned.

        Args:
            scheduler_output: The scheduler output for this step.

        Returns:
            The runner's `ModelRunnerOutput` / `AsyncModelRunnerOutput` when
            it produces one. Otherwise `None` when there is no KV connector
            output, `EMPTY_MODEL_RUNNER_OUTPUT` when the connector output has
            neither finished sends nor finished receives, or a shallow copy
            of `EMPTY_MODEL_RUNNER_OUTPUT` carrying the connector output.
        """
        intermediate_tensors = None
        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
        if forward_pass and not get_pp_group().is_first_rank:
            # intermediate_tensors = IntermediateTensors(
            #     get_pp_group().recv_tensor_dict(
            #         all_gather_group=get_tp_group()))
            # use cpu send/recv
            if br_envs.VLLM_PP_CPU_SEND_RECV:
                cpu_dict = get_pp_group().recv_tensor_dict()
                # Move the received tensors onto the current SUPA device.
                gpu_dict = {
                    k: v.to(torch.supa.current_device())
                    for k, v in cpu_dict.items()
                }
                intermediate_tensors = IntermediateTensors(gpu_dict)
            else:
                intermediate_tensors = IntermediateTensors(
                    get_pp_group().recv_tensor_dict())

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert parallel_config.distributed_executor_backend != (
            "external_launcher") and not get_pp_group().is_last_rank

        # use cpu send/recv
        if br_envs.VLLM_PP_CPU_SEND_RECV:
            cpu_dict = {k: v.cpu() for k, v in output.tensors.items()}
            get_pp_group().send_tensor_dict(cpu_dict)
        else:
            get_pp_group().send_tensor_dict(output.tensors)

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if (not kv_connector_output.finished_sending
                and not kv_connector_output.finished_recving):
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Copy so the shared EMPTY_MODEL_RUNNER_OUTPUT object is not mutated.
        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Return the draft token ids taken from the model runner."""
        return self.model_runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler.

        Args:
            is_start: Start the profiler when true, stop it otherwise.

        Raises:
            RuntimeError: If no profiler was configured in `__init__`.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def execute_dummy_batch(self) -> None:
        """Run the model runner's dummy run with a size of 1."""
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA add request to the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal request to the model runner."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model runner."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a LoRA pin request to the model runner."""
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        """Health check; a no-op."""
        # worker will always be healthy as long as it's running.
        return

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the runner's model through `ShardedStateLoader.save_model`.

        Args:
            path: Forwarded to `ShardedStateLoader.save_model`.
            pattern: Forwarded as the `pattern` keyword argument.
            max_size: Forwarded as the `max_size` keyword argument.
        """
        # The `model_loader.loader` submodule is presumably absent from the
        # vLLM release this file targets (TODO confirm against the pinned
        # version), so import from the package first and fall back to the
        # legacy location.
        try:
            from vllm.model_executor.model_loader import ShardedStateLoader
        except ImportError:
            from vllm.model_executor.model_loader.loader import (
                ShardedStateLoader)
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def _init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        set_custom_all_reduce(
            not self.parallel_config.disable_custom_all_reduce)
        init_distributed_environment(self.parallel_config.world_size,
                                     self.rank,
                                     self.distributed_init_method,
                                     self.local_rank,
                                     "sccl",
                                     timeout=datetime.timedelta(seconds=100))
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)
        ensure_kv_transfer_initialized(self.vllm_config)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
# TODO: add checkers
return
if torch_dtype == torch.bfloat16: # noqa: SIM102
capability = SUPAPlatform.get_device_capability()
gpu_name = SUPAPlatform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.")