first commit
This commit is contained in:
15
vllm_br/v1/worker/__init__.py
Normal file
15
vllm_br/v1/worker/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
BIN
vllm_br/v1/worker/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/worker/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm_br/v1/worker/__pycache__/model_runner.cpython-310.pyc
Normal file
BIN
vllm_br/v1/worker/__pycache__/model_runner.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/v1/worker/__pycache__/ubatching.cpython-310.pyc
Normal file
BIN
vllm_br/v1/worker/__pycache__/ubatching.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/worker/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm_br/v1/worker/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/worker/__pycache__/worker.cpython-310.pyc
Normal file
BIN
vllm_br/v1/worker/__pycache__/worker.cpython-310.pyc
Normal file
Binary file not shown.
49
vllm_br/v1/worker/kv_connector_model_runner_mixin.py
Normal file
49
vllm_br/v1/worker/kv_connector_model_runner_mixin.py
Normal file
@@ -0,0 +1,49 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorModelRunnerMixin)
|
||||
from vllm_br.forward_context import set_forward_context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
# @staticmethod
|
||||
def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
|
||||
vllm_config: VllmConfig) -> ModelRunnerOutput:
|
||||
# KV send/recv even if no work to do.
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config), KVConnectorModelRunnerMixin._get_kv_connector_output(
|
||||
scheduler_output, wait_for_save=False) as kv_connector_output:
|
||||
pass
|
||||
|
||||
if (not kv_connector_output.finished_sending
|
||||
and not kv_connector_output.finished_recving):
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
output.kv_connector_output = kv_connector_output
|
||||
return output
|
||||
|
||||
|
||||
KVConnectorModelRunnerMixin.kv_connector_no_forward = kv_connector_no_forward
|
||||
4595
vllm_br/v1/worker/model_runner.py
Normal file
4595
vllm_br/v1/worker/model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
413
vllm_br/v1/worker/supa_ubatch_wrapper.py
Normal file
413
vllm_br/v1/worker/supa_ubatch_wrapper.py
Normal file
@@ -0,0 +1,413 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id)
|
||||
from vllm.forward_context import (create_forward_context, get_forward_context,
|
||||
override_forward_context)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||
from vllm_br.compilation.supa_graph import SUPAGraphWrapper
|
||||
from vllm_br.config.compilation import SUPAGraphMode
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UbatchMetadata:
|
||||
context: UBatchContext
|
||||
input_ids: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
inputs_embeds: Optional[torch.Tensor]
|
||||
intermediate_tensors: Optional[IntermediateTensors]
|
||||
num_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class SUPAGraphMetaData:
|
||||
supagraph: torch.supa.SUPAGraph
|
||||
ubatch_metadata: UbatchMetadata
|
||||
outputs: Optional[Any] = None
|
||||
|
||||
|
||||
class SMControlContextManager:
|
||||
|
||||
def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None],
|
||||
set_compute_sms: Callable[[int], None]):
|
||||
"""
|
||||
Context manager for controlling SM (Streaming Multiprocessor)
|
||||
allocation. Upon entering the context, it sets the number of SMs
|
||||
allocated for communication and computation to comm_sms and
|
||||
total_sms - comm_sms respectively. Upon exiting, it restores the
|
||||
allocation to use all available SMs (i.e. total_sms).
|
||||
|
||||
Args:
|
||||
comm_sms (int): The number of SMs to allocate for communication.
|
||||
(The remainder will be used for computation.)
|
||||
set_comm_sms (Callable[[int], None]):
|
||||
A function that sets the number of SMs for communication.
|
||||
set_compute_sms (Callable[[int], None]):
|
||||
A function that sets the number of SMs for computation.
|
||||
"""
|
||||
|
||||
assert current_platform.is_supa(), \
|
||||
"SM control is currently only supported on SUPA"
|
||||
|
||||
props = torch.supa.get_device_properties(torch.supa.current_device())
|
||||
total_sms = props.multi_processor_count
|
||||
|
||||
assert comm_sms < total_sms
|
||||
self.total_sms = total_sms
|
||||
self.compute_sms = total_sms - comm_sms
|
||||
self.comm_sms = comm_sms
|
||||
self.set_comm_sms = set_comm_sms
|
||||
self.set_compute_sms = set_compute_sms
|
||||
|
||||
def __enter__(self):
|
||||
self.set_comm_sms(self.comm_sms)
|
||||
self.set_compute_sms(self.compute_sms)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.set_comm_sms(self.total_sms)
|
||||
self.set_compute_sms(self.total_sms)
|
||||
|
||||
|
||||
class UBatchWrapper:
|
||||
|
||||
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
|
||||
runtime_mode: SUPAGraphMode, device: torch.supa.device):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.comm_stream = torch.supa.Stream(device=device)
|
||||
# Two ubatch threads plus the main thread
|
||||
self.ready_barrier = threading.Barrier(3)
|
||||
|
||||
self.supagraphs: dict[int, SUPAGraphMetaData] = {}
|
||||
|
||||
self.supagraph_wrapper = None
|
||||
self.graph_pool = None
|
||||
if runtime_mode is not SUPAGraphMode.NONE:
|
||||
self.supagraph_wrapper = SUPAGraphWrapper(
|
||||
runnable, vllm_config, runtime_mode=runtime_mode)
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
self.sm_control = self._create_sm_control_context(vllm_config)
|
||||
self.device = device
|
||||
|
||||
@staticmethod
|
||||
def _create_sm_control_context(vllm_config: VllmConfig):
|
||||
comm_sms = envs.VLLM_DBO_COMM_SMS
|
||||
|
||||
set_comm_sms = lambda sms: None
|
||||
if vllm_config.parallel_config.enable_expert_parallel:
|
||||
# Currently only DeepEP highthroughput supports SM control so this
|
||||
# only affects that case.
|
||||
all2all_manager = get_ep_group(
|
||||
).device_communicator.all2all_manager
|
||||
|
||||
if all2all_manager.max_sms_used() is not None:
|
||||
comm_sms = min(comm_sms, all2all_manager.max_sms_used())
|
||||
|
||||
if comm_sms > 0:
|
||||
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
|
||||
|
||||
# TODO(lucas): support other kernels besides DeepGEMM
|
||||
set_compute_sms = lambda sms: None
|
||||
if has_deep_gemm() and comm_sms > 0:
|
||||
import deep_gemm as dg
|
||||
set_compute_sms = lambda sms: dg.set_num_sms(sms)
|
||||
|
||||
return SMControlContextManager(comm_sms=comm_sms,
|
||||
set_comm_sms=set_comm_sms,
|
||||
set_compute_sms=set_compute_sms)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"supagraph wrapper: {self.runnable}")
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
|
||||
"""
|
||||
Capture a supagraph for a microbatched run.
|
||||
|
||||
The logic here is somewhat complicated because we need to make sure that
|
||||
each of the ubatch threads initialize the supa context before we start
|
||||
the graph capture.
|
||||
|
||||
The flow is as follows:
|
||||
1. The main thread starts up each ubatch thread. Each thread will
|
||||
initialize its supa context (torch.supa.current_blas_handle())
|
||||
before going to sleep upon entering the ubatch_context.
|
||||
|
||||
2. The main thread starts the graph capture and wakes up the first
|
||||
ubatch thread.
|
||||
|
||||
3. Each ubatch thread runs the model to completion and returns the
|
||||
completed output tensors back to the main thread.
|
||||
|
||||
4. The main thread stores the captured supagraph along with its metadata
|
||||
and returns
|
||||
"""
|
||||
|
||||
@torch.inference_mode()
|
||||
def _capture_ubatch_thread(results, ubatch_metadata):
|
||||
torch.supa.set_device(self.device)
|
||||
ubatch_context = ubatch_metadata.context
|
||||
with torch.supa.stream(ubatch_context.compute_stream):
|
||||
_ = torch.supa.current_blas_handle()
|
||||
with torch.supa.stream(ubatch_context.comm_stream):
|
||||
_ = torch.supa.current_blas_handle()
|
||||
with ubatch_context:
|
||||
model_output = model(
|
||||
input_ids=ubatch_metadata.input_ids,
|
||||
positions=ubatch_metadata.positions,
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors,
|
||||
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||
)
|
||||
|
||||
results.append((ubatch_metadata.context.id, model_output))
|
||||
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
compute_stream = ubatch_metadata[0].context.compute_stream
|
||||
num_tokens = ubatch_metadata[0].num_tokens + \
|
||||
ubatch_metadata[1].num_tokens
|
||||
|
||||
# Ubatches will manually manage the forward context, so we override
|
||||
# it to None here so we can have it restored correctly later
|
||||
with override_forward_context(None):
|
||||
ubatch_threads = []
|
||||
for metadata in ubatch_metadata:
|
||||
thread = threading.Thread(target=_capture_ubatch_thread,
|
||||
args=(
|
||||
results,
|
||||
metadata,
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
self.ready_barrier.wait() # Wait for both threads to be ready
|
||||
|
||||
# Capture the supagraph
|
||||
supagraph_metadata = \
|
||||
SUPAGraphMetaData(
|
||||
supagraph=torch.supa.SUPAGraph(),
|
||||
ubatch_metadata=ubatch_metadata,
|
||||
)
|
||||
if self.graph_pool is not None:
|
||||
set_graph_pool_id(self.graph_pool)
|
||||
else:
|
||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||
with torch.supa.graph(supagraph_metadata.supagraph,
|
||||
stream=compute_stream,
|
||||
pool=self.graph_pool):
|
||||
ubatch_metadata[0].context.cpu_wait_event.set()
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
sorted_results = [value for position, value in sorted(results)]
|
||||
result = torch.cat(sorted_results, dim=0)
|
||||
supagraph_metadata.outputs = result
|
||||
self.supagraphs[num_tokens] = supagraph_metadata
|
||||
return supagraph_metadata.outputs
|
||||
|
||||
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
|
||||
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(results, model, ubatch_metadata):
|
||||
with ubatch_metadata.context:
|
||||
model_output = model(
|
||||
input_ids=ubatch_metadata.input_ids,
|
||||
positions=ubatch_metadata.positions,
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors,
|
||||
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||
)
|
||||
results.append((ubatch_metadata.context.id, model_output))
|
||||
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
|
||||
# Ubatch threads will manually manage the forward context, so we
|
||||
# override it to None here so we can have it restored correctly
|
||||
# after both threads have finished
|
||||
with override_forward_context(None):
|
||||
ubatch_threads = []
|
||||
for metadata in ubatch_metadata:
|
||||
thread = threading.Thread(target=_ubatch_thread,
|
||||
args=(
|
||||
results,
|
||||
model,
|
||||
metadata,
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
self.ready_barrier.wait() # Wait for both threads to be ready
|
||||
ubatch_metadata[0].context.cpu_wait_event.set()
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
sorted_results = [value for position, value in sorted(results)]
|
||||
result = torch.cat(sorted_results, dim=0)
|
||||
return result
|
||||
|
||||
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
|
||||
positions, inputs_embeds, intermediate_tensors,
|
||||
compute_stream, dp_metadata, batch_descriptor,
|
||||
supagraph_runtime_mode) -> list[UbatchMetadata]:
|
||||
|
||||
# Create one forward context per ubatch
|
||||
forward_contexts = []
|
||||
for i, ubatch_slice in enumerate(ubatch_slices):
|
||||
forward_contexts.append(
|
||||
create_forward_context(
|
||||
attn_metadata[i] if attn_metadata is not None else None,
|
||||
self.vllm_config,
|
||||
dp_metadata=dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
supagraph_runtime_mode=supagraph_runtime_mode))
|
||||
|
||||
ubatch_ctxs = make_ubatch_contexts(
|
||||
num_micro_batches=len(ubatch_slices),
|
||||
comm_stream=self.comm_stream,
|
||||
compute_stream=compute_stream,
|
||||
forward_contexts=forward_contexts,
|
||||
ready_barrier=self.ready_barrier)
|
||||
|
||||
ubatch_metadata: list[UbatchMetadata] = []
|
||||
for i, ubatch_slice in enumerate(ubatch_slices):
|
||||
sliced_input_ids, sliced_positions, sliced_inputs_embeds, \
|
||||
sliced_intermediate_tensors = \
|
||||
self._slice_model_inputs(
|
||||
ubatch_slice.token_slice, input_ids, positions,
|
||||
inputs_embeds, intermediate_tensors)
|
||||
ubatch_metadata.append(
|
||||
UbatchMetadata(
|
||||
context=ubatch_ctxs[i],
|
||||
input_ids=sliced_input_ids,
|
||||
positions=sliced_positions,
|
||||
inputs_embeds=sliced_inputs_embeds,
|
||||
intermediate_tensors=sliced_intermediate_tensors,
|
||||
num_tokens=ubatch_slice.token_slice.stop -
|
||||
ubatch_slice.token_slice.start))
|
||||
|
||||
return ubatch_metadata
|
||||
|
||||
def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions,
|
||||
inputs_embeds, intermediate_tensors):
|
||||
sliced_input_ids = input_ids[tokens_slice]
|
||||
# if we are using mrope. Mrope adds an additional dimension to the
|
||||
# positions tensor
|
||||
if positions.ndim == 2:
|
||||
sliced_positions = positions[:, tokens_slice]
|
||||
else:
|
||||
sliced_positions = positions[tokens_slice]
|
||||
sliced_inputs_embeds = inputs_embeds[
|
||||
tokens_slice] if inputs_embeds else None
|
||||
sliced_intermediate_tensors = intermediate_tensors[
|
||||
tokens_slice] if intermediate_tensors else None
|
||||
|
||||
return (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
|
||||
sliced_intermediate_tensors)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
ubatch_slices = forward_context.ubatch_slices
|
||||
supagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
# If there's no ubatching, just run the runnable object
|
||||
if ubatch_slices is None:
|
||||
|
||||
# This is to account for the case where ubatching was aborted.
|
||||
# When we capture full graphs we only capture one graph per shape,
|
||||
# meaning that if we have a ubatched supagraph for the current
|
||||
# num_tokens, we don't have a non-ubatched one. Without this
|
||||
# check, the supagraph wrapper will try to capture a supagraph
|
||||
# for this shape during a normal run.
|
||||
if supagraph_runtime_mode is SUPAGraphMode.FULL:
|
||||
assert batch_descriptor is not None
|
||||
if batch_descriptor.num_tokens in self.supagraphs:
|
||||
supagraph_runtime_mode = SUPAGraphMode.NONE
|
||||
|
||||
if supagraph_runtime_mode in (SUPAGraphMode.NONE,
|
||||
SUPAGraphMode.PIECEWISE):
|
||||
return self.runnable(*args, **kwargs)
|
||||
else:
|
||||
assert self.supagraph_wrapper is not None
|
||||
return self.supagraph_wrapper(*args, **kwargs)
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
num_tokens = (ubatch_slices[0].token_slice.stop -
|
||||
ubatch_slices[0].token_slice.start) * 2
|
||||
input_ids = kwargs['input_ids']
|
||||
positions = kwargs['positions']
|
||||
intermediate_tensors = kwargs['intermediate_tensors']
|
||||
inputs_embeds = kwargs['inputs_embeds']
|
||||
compute_stream = torch.supa.current_stream()
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
|
||||
# We shouldn't be here unless we are running with multiple DP ranks
|
||||
assert dp_metadata is not None
|
||||
|
||||
if num_tokens not in self.supagraphs \
|
||||
and supagraph_runtime_mode is SUPAGraphMode.FULL:
|
||||
ubatch_metadata = self._make_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
compute_stream=compute_stream,
|
||||
dp_metadata=dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
supagraph_runtime_mode=SUPAGraphMode.NONE)
|
||||
with self.sm_control:
|
||||
return self._capture_ubatches(ubatch_metadata, self.model)
|
||||
elif num_tokens in self.supagraphs \
|
||||
and supagraph_runtime_mode is SUPAGraphMode.FULL:
|
||||
supagraph_metadata = self.supagraphs[num_tokens]
|
||||
supagraph_metadata.supagraph.replay()
|
||||
return supagraph_metadata.outputs
|
||||
else:
|
||||
ubatch_metadata = self._make_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
compute_stream=compute_stream,
|
||||
dp_metadata=dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
supagraph_runtime_mode=SUPAGraphMode.NONE)
|
||||
with self.sm_control:
|
||||
return self._run_ubatches(ubatch_metadata, self.model)
|
||||
155
vllm_br/v1/worker/supagraph_dispatcher.py
Normal file
155
vllm_br/v1/worker/supagraph_dispatcher.py
Normal file
@@ -0,0 +1,155 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm_br.config.compilation import SUPAGraphMode
|
||||
from vllm_br.forward_context import BatchDescriptor
|
||||
|
||||
_BATCH_SIZE_ALIGNMENT = 8
|
||||
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
|
||||
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 34)
|
||||
]
|
||||
|
||||
|
||||
class SupagraphDispatcher:
|
||||
"""
|
||||
Runtime supagraph dispatcher to dispatch keys for multiple set of
|
||||
supagraphs.
|
||||
|
||||
The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one
|
||||
for FULL supagraph runtime mode. The keys are initialized depending on
|
||||
attention support and what supagraph mode is set in CompilationConfig. The
|
||||
keys stored in dispatcher are the only source of truth for valid
|
||||
supagraphs that can be dispatched at runtime.
|
||||
|
||||
At runtime, the dispatch method generates the runtime supagraph mode (FULL,
|
||||
PIECEWISE, or NONE for no supagraph) and the valid key (batch descriptor)
|
||||
based on the input key. After dispatching (communicate via forward context),
|
||||
the supagraph wrappers will trust the dispatch key to do either capturing
|
||||
or replaying (if mode matched), or pass through to the underlying runnable
|
||||
without supagraph (if mode no match or mode is NONE).
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.use_default_list = vllm_config.compilation_config.max_capture_size > 256 or self.compilation_config.max_capture_size == 0
|
||||
self.capture_list = _BATCH_SIZES_TO_CAPTURE if self.use_default_list else self.compilation_config.cudagraph_capture_sizes
|
||||
|
||||
# TODO(liming): Remove this hard code once we support piecewise
|
||||
self.supagraph_mode = SUPAGraphMode.FULL
|
||||
|
||||
# Dict to store valid supagraph dispatching keys.
|
||||
self.supagraph_keys: dict[SUPAGraphMode, set[BatchDescriptor]] = {
|
||||
SUPAGraphMode.PIECEWISE: set(),
|
||||
SUPAGraphMode.FULL: set(),
|
||||
SUPAGraphMode.FULL_DECODE_ONLY: set(),
|
||||
}
|
||||
|
||||
assert not self.supagraph_mode.requires_piecewise_compilation() or \
|
||||
(self.compilation_config.level == CompilationLevel.PIECEWISE and
|
||||
self.compilation_config.splitting_ops_contain_attention()), \
|
||||
"Compilation level should be CompilationLevel.PIECEWISE when "\
|
||||
"supagraph_mode piecewise supagraphs is used, "\
|
||||
f"supagraph_mode={self.supagraph_mode}, "\
|
||||
f"compilation_level={self.compilation_config.level}, "\
|
||||
f"splitting_ops={self.compilation_config.splitting_ops}"
|
||||
|
||||
self.keys_initialized = False
|
||||
|
||||
def add_supagraph_key(self, runtime_mode: SUPAGraphMode,
|
||||
batch_descriptor: BatchDescriptor):
|
||||
assert runtime_mode in [SUPAGraphMode.PIECEWISE,SUPAGraphMode.FULL_DECODE_ONLY, SUPAGraphMode.FULL], \
|
||||
f"Invalid supagraph runtime mode: {runtime_mode}"
|
||||
self.supagraph_keys[runtime_mode].add(batch_descriptor)
|
||||
|
||||
def initialize_supagraph_keys(self, supagraph_mode: SUPAGraphMode,
|
||||
uniform_decode_query_len: int):
|
||||
# This should be called only after attention backend is initialized.
|
||||
|
||||
# Note: we create all valid keys possible for supagraph but do not
|
||||
# guarantee all keys would be used. For example, we create keys for
|
||||
# piecewise supagraphs when it is piecewise compilation, which is always
|
||||
# valid, but for attention backend support unified routine, we may not
|
||||
# trigger capturing/replaying the piecewise supagraphs depending on
|
||||
# CompilationConfig.supagraph_mode. In addition, if we allow lazy
|
||||
# capturing in future PR, some keys may never be triggered.
|
||||
if supagraph_mode == SUPAGraphMode.FULL:
|
||||
max_num_tokens = (uniform_decode_query_len *
|
||||
self.vllm_config.scheduler_config.max_num_seqs)
|
||||
supagraph_capture_sizes_for_decode = [
|
||||
x for x in self.capture_list
|
||||
if x <= max_num_tokens and x >= uniform_decode_query_len
|
||||
]
|
||||
for bs in supagraph_capture_sizes_for_decode:
|
||||
self.add_supagraph_key(
|
||||
supagraph_mode,
|
||||
BatchDescriptor(num_tokens=bs, uniform_decode=True))
|
||||
|
||||
# if decode supagraph mode is FULL, and we don't already have mixed
|
||||
# mode full supagraphs then add them here.
|
||||
if supagraph_mode == SUPAGraphMode.FULL_DECODE_ONLY:
|
||||
max_num_tokens = uniform_decode_query_len * \
|
||||
self.vllm_config.scheduler_config.max_num_seqs
|
||||
supagraph_capture_sizes_for_decode = [
|
||||
x for x in self.capture_list
|
||||
if x <= max_num_tokens and x >= uniform_decode_query_len
|
||||
]
|
||||
for bs in supagraph_capture_sizes_for_decode:
|
||||
self.add_supagraph_key(
|
||||
supagraph_mode,
|
||||
BatchDescriptor(num_tokens=bs, uniform_decode=True))
|
||||
self.keys_initialized = True
|
||||
|
||||
def dispatch(
|
||||
self, batch_descriptor: BatchDescriptor
|
||||
) -> tuple[SUPAGraphMode, Optional[BatchDescriptor]]:
|
||||
"""
|
||||
Given a batch descriptor, dispatch to a supagraph mode.
|
||||
A new batch descriptor is returned as we might dispatch a uniform batch
|
||||
to a graph that supports a more general batch (uniform to non-uniform).
|
||||
"""
|
||||
# if not initialized, just skip dispatching.
|
||||
if not self.keys_initialized:
|
||||
logger.warning_once("supagraph dispatching keys are not "
|
||||
"initialized. No supagraph will be used.")
|
||||
return SUPAGraphMode.NONE, None
|
||||
|
||||
if batch_descriptor in self.supagraph_keys[
|
||||
SUPAGraphMode.FULL_DECODE_ONLY]:
|
||||
return SUPAGraphMode.FULL_DECODE_ONLY, batch_descriptor
|
||||
|
||||
# check if key exists for full supagraph
|
||||
if batch_descriptor in self.supagraph_keys[SUPAGraphMode.FULL]:
|
||||
return SUPAGraphMode.FULL, batch_descriptor
|
||||
|
||||
# # otherwise, check if non-uniform key exists
|
||||
non_uniform_key = batch_descriptor.non_uniform
|
||||
if non_uniform_key in self.supagraph_keys[SUPAGraphMode.FULL]:
|
||||
return SUPAGraphMode.FULL, non_uniform_key
|
||||
|
||||
|
||||
#
|
||||
# # also check if non-uniform key exists for more "general"
|
||||
# # piecewise supagraph
|
||||
# if non_uniform_key in self.supagraph_keys[SUPAGraphMode.PIECEWISE]:
|
||||
# return SUPAGraphMode.PIECEWISE, non_uniform_key
|
||||
|
||||
# finally, just return no supagraphs
|
||||
return SUPAGraphMode.NONE, None
|
||||
195
vllm_br/v1/worker/ubatching.py
Normal file
195
vllm_br/v1/worker/ubatching.py
Normal file
@@ -0,0 +1,195 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import threading
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import forward_context
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import current_stream
|
||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||
|
||||
|
||||
class SUPAUBatchContext:
|
||||
"""
|
||||
Context manager for micro-batching synchronization using threading events.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
id: int,
|
||||
comm_stream: torch.supa.Stream,
|
||||
compute_stream: torch.supa.Stream,
|
||||
forward_context: ForwardContext,
|
||||
ready_barrier: threading.Barrier,
|
||||
cpu_wait_event: threading.Event,
|
||||
cpu_signal_event: threading.Event,
|
||||
gpu_comm_done_event: torch.supa.Event,
|
||||
gpu_compute_done_event: torch.supa.Event,
|
||||
schedule: str = "default"):
|
||||
self.id = id
|
||||
self.comm_stream = comm_stream
|
||||
self.compute_stream = compute_stream
|
||||
self.forward_context = forward_context
|
||||
self.ready_barrier = ready_barrier
|
||||
self.cpu_wait_event = cpu_wait_event
|
||||
self.cpu_signal_event = cpu_signal_event
|
||||
self.current_stream = compute_stream
|
||||
self.gpu_comm_done_event = gpu_comm_done_event
|
||||
self.gpu_compute_done_event = gpu_compute_done_event
|
||||
self.schedule = schedule
|
||||
self.recv_hook = None
|
||||
|
||||
def __enter__(self):
|
||||
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
|
||||
_THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
|
||||
_CURRENT_CONTEXTS[self.id] = self
|
||||
self.ready_barrier.wait()
|
||||
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
# Assume we want to start on the compute stream
|
||||
self.update_stream(self.compute_stream)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
|
||||
_CURRENT_CONTEXTS[self.id] = None
|
||||
del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
|
||||
self.maybe_run_recv_hook()
|
||||
self.cpu_signal_event.set()
|
||||
self.cpu_wait_event.clear()
|
||||
return False
|
||||
|
||||
def _restore_context(self):
|
||||
forward_context._forward_context = self.forward_context
|
||||
|
||||
def update_stream(self, stream):
|
||||
self.current_stream = stream
|
||||
if current_stream() != self.current_stream:
|
||||
torch.supa.set_stream(self.current_stream)
|
||||
|
||||
def _signal_comm_done(self):
|
||||
self.gpu_comm_done_event.record(self.comm_stream)
|
||||
|
||||
def _signal_compute_done(self):
|
||||
self.gpu_compute_done_event.record(self.compute_stream)
|
||||
|
||||
def _wait_compute_done(self):
|
||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||
|
||||
def _wait_comm_done(self):
|
||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||
|
||||
def _cpu_yield(self):
|
||||
# It is critical for correctness that only one thread is running
|
||||
# at a time. These asserts just make sure that this is the only
|
||||
# thread running before waking the other one up and going to sleep
|
||||
assert forward_context._forward_context == self.forward_context
|
||||
assert current_stream() == self.current_stream
|
||||
assert not self.cpu_wait_event.is_set()
|
||||
|
||||
self.cpu_signal_event.set()
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
|
||||
def switch_to_comm(self):
|
||||
self.update_stream(self.comm_stream)
|
||||
|
||||
def switch_to_compute(self):
|
||||
self.update_stream(self.compute_stream)
|
||||
|
||||
def switch_to_comm_sync(self):
|
||||
self._signal_compute_done()
|
||||
self.update_stream(self.comm_stream)
|
||||
self._wait_compute_done()
|
||||
|
||||
def switch_to_compute_sync(self):
|
||||
self._signal_comm_done()
|
||||
self.update_stream(self.compute_stream)
|
||||
self._wait_comm_done()
|
||||
|
||||
def maybe_run_recv_hook(self):
|
||||
if self.recv_hook is not None:
|
||||
self.recv_hook()
|
||||
self.recv_hook = None
|
||||
|
||||
def yield_(self):
|
||||
self.current_stream = current_stream()
|
||||
self._cpu_yield()
|
||||
self.update_stream(self.current_stream)
|
||||
|
||||
def yield_and_switch_from_compute_to_comm(self):
|
||||
assert current_stream() == self.compute_stream
|
||||
self._signal_compute_done()
|
||||
self._cpu_yield()
|
||||
assert self.current_stream == self.compute_stream
|
||||
self.update_stream(self.comm_stream)
|
||||
self._wait_compute_done()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(self):
|
||||
assert current_stream() == self.comm_stream
|
||||
self._signal_comm_done()
|
||||
self._cpu_yield()
|
||||
assert self.current_stream == self.comm_stream
|
||||
self.update_stream(self.compute_stream)
|
||||
self._wait_comm_done()
|
||||
|
||||
|
||||
def supa_make_ubatch_contexts(
|
||||
num_micro_batches: int,
|
||||
compute_stream: torch.supa.Stream,
|
||||
comm_stream: torch.supa.Stream,
|
||||
forward_contexts: list[ForwardContext],
|
||||
ready_barrier: threading.Barrier,
|
||||
schedule: str = "default",
|
||||
) -> list[UBatchContext]:
|
||||
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
|
||||
"""
|
||||
Create a context manager for micro-batching synchronization.
|
||||
"""
|
||||
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
|
||||
gpu_comm_done_events = [
|
||||
torch.supa.Event() for _ in range(num_micro_batches)
|
||||
]
|
||||
gpu_compute_done_events = [
|
||||
torch.supa.Event() for _ in range(num_micro_batches)
|
||||
]
|
||||
|
||||
assert len(forward_contexts) == 2
|
||||
|
||||
ctxs = []
|
||||
for i in range(num_micro_batches):
|
||||
ctx = UBatchContext(id=i,
|
||||
compute_stream=compute_stream,
|
||||
comm_stream=comm_stream,
|
||||
forward_context=forward_contexts[i],
|
||||
ready_barrier=ready_barrier,
|
||||
cpu_wait_event=cpu_events[i],
|
||||
cpu_signal_event=cpu_events[(i + 1) %
|
||||
num_micro_batches],
|
||||
gpu_comm_done_event=gpu_comm_done_events[i],
|
||||
gpu_compute_done_event=gpu_compute_done_events[i],
|
||||
schedule=schedule)
|
||||
ctxs.append(ctx)
|
||||
|
||||
return ctxs
|
||||
|
||||
|
||||
UBatchContext = SUPAUBatchContext
|
||||
make_ubatch_contexts = supa_make_ubatch_contexts
|
||||
86
vllm_br/v1/worker/utils.py
Normal file
86
vllm_br/v1/worker/utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
|
||||
def bind_kv_cache(
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
forward_context: dict[str, "Attention"],
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
num_attn_module: Optional[int] = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Bind the allocated KV cache to both ModelRunner and forward context so
|
||||
that the KV cache can be used in the forward pass.
|
||||
|
||||
This function:
|
||||
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
|
||||
kv_caches.
|
||||
2) Associates each attention layer in the `forward_context` with its
|
||||
corresponding KV cache in kv_caches.
|
||||
|
||||
Args:
|
||||
kv_caches: The allocated kv_caches with layer names as keys.
|
||||
forward_context: The global forward context containing all Attention
|
||||
layers with layer names as keys.
|
||||
runner_kv_caches: The kv_cache declared by ModelRunner.
|
||||
"""
|
||||
# Bind kv_caches to ModelRunner
|
||||
assert len(runner_kv_caches) == 0
|
||||
|
||||
# Convert kv_caches dict to a list of tensors in the order of layer_index.
|
||||
index2name = defaultdict(list)
|
||||
for layer_name in kv_caches:
|
||||
index2name[extract_layer_index(layer_name,
|
||||
num_attn_module)].append(layer_name)
|
||||
|
||||
for layer_index in sorted(index2name.keys()):
|
||||
layer_names = index2name[layer_index]
|
||||
if len(layer_names) > 1:
|
||||
# One typical case is encoder-decoder model, e.g., bart.
|
||||
# The cross attention and self attention in the same decoder layer
|
||||
# has different layer_name but the same layer_index.
|
||||
|
||||
# TODO - analyze where runner_kv_caches is used and the right
|
||||
# way to ensure it properly reflects multiple attention layers
|
||||
# in the same decoder block.
|
||||
if current_platform.is_cuda() or current_platform.is_xpu(
|
||||
) or current_platform.is_supa():
|
||||
# We know that the GPU runner is not impacted by this
|
||||
# case. Some test code depends on runner_kv_caches, but
|
||||
# not in a way that's impacted by ignoring this.
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError
|
||||
layer_name = layer_names[0]
|
||||
runner_kv_caches.append(kv_caches[layer_name])
|
||||
|
||||
# Bind kv_caches to forward context
|
||||
for layer_name, kv_cache in kv_caches.items():
|
||||
# NOTE: Use list because of v0 PP virtual engine.
|
||||
forward_context[layer_name].kv_cache = [kv_cache]
|
||||
429
vllm_br/v1/worker/worker.py
Normal file
429
vllm_br/v1/worker/worker.py
Normal file
@@ -0,0 +1,429 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A GPU worker class."""
|
||||
import copy
|
||||
import datetime
|
||||
import gc
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm_br.envs as br_envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||
DraftTokenIds, ModelRunnerOutput)
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
from vllm_br.platform import SUPAPlatform
|
||||
from vllm_br.utils import GiB_bytes, SUPAMemorySnapshot
|
||||
from vllm_br.v1.worker.model_runner import SUPAModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
class SUPAWorker(WorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
self.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Buffers saved before sleep
|
||||
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info(
|
||||
"Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir,
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, use_gzip=True),
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.SUPA, # type: ignore
|
||||
],
|
||||
schedule=torch.profiler.schedule(wait=0,
|
||||
warmup=0,
|
||||
active=1,
|
||||
repeat=1),
|
||||
profile_memory=False,
|
||||
record_shapes=True,
|
||||
with_stack=False,
|
||||
use_supa_simple=True, # type: ignore
|
||||
)
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
def init_device(self):
|
||||
if self.device_config.device.type == "supa":
|
||||
self.device = torch.device(f"supa:{self.local_rank}")
|
||||
if self.kv_transfer_config is not None:
|
||||
device_cursor = self.kv_transfer_config.get_from_extra_config(
|
||||
"device_cursor", 0)
|
||||
self.device = torch.device(
|
||||
f"supa:{self.local_rank + int(device_cursor)}")
|
||||
SUPAPlatform.set_device(self.device)
|
||||
|
||||
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
# Initialize the distributed environment BEFORE taking
|
||||
# memory snapshot
|
||||
# This ensures SUCCL buffers are allocated before we measure
|
||||
# available memory
|
||||
self._init_worker_distributed_environment()
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
gc.collect()
|
||||
torch.supa.empty_cache()
|
||||
self.init_gpu_memory = SUPAPlatform.get_device_total_memory()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner: SUPAModelRunner = SUPAModelRunner( # type: ignore
|
||||
self.vllm_config, self.device)
|
||||
|
||||
if self.rank == 0:
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
report_usage_stats(self.vllm_config)
|
||||
|
||||
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
||||
# to hijack tensor allocation.
|
||||
def load_model(self) -> None:
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
raise NotImplementedError('SUPA do not support sleep mode')
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.load_model()
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_available_memory(self) -> int:
|
||||
"""Profiles the peak memory usage of the model to determine how much
|
||||
memory can be used for KV cache without OOMs.
|
||||
|
||||
The engine will first conduct a profiling of the existing memory usage.
|
||||
Then, it calculate the free memory that can be used for KV cache in
|
||||
bytes.
|
||||
|
||||
.. tip::
|
||||
You may limit the usage of GPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
torch.supa.empty_cache()
|
||||
|
||||
_, total_gpu_memory = torch.supa.mem_get_info()
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
before_profile = SUPAMemorySnapshot()
|
||||
after_profile = SUPAMemorySnapshot()
|
||||
before_profile.measure()
|
||||
self.model_runner.profile_run()
|
||||
after_profile.measure()
|
||||
|
||||
free_gpu_memory, _ = torch.supa.mem_get_info()
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
assert self.init_gpu_memory > free_gpu_memory, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial free memory {self.init_gpu_memory}, current free memory"
|
||||
f" {free_gpu_memory}. This happens when the GPU memory was "
|
||||
"not properly cleaned up before initializing the vLLM instance.")
|
||||
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
peak_memory = torch.supa.memory_allocated()
|
||||
# Check for any memory left around that may have been allocated on the
|
||||
# gpu outside of `torch`. NCCL operations, for example, can use a few
|
||||
# GB during a forward pass
|
||||
torch.supa.empty_cache()
|
||||
torch_allocated_bytes = SUPAPlatform.get_memory_stats(
|
||||
self.device, "allocated_bytes.all.current")
|
||||
total_allocated_bytes = (torch.supa.mem_get_info()[1] -
|
||||
torch.supa.mem_get_info()[0])
|
||||
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
|
||||
#if non_torch_allocations > 0:
|
||||
# peak_memory += non_torch_allocations
|
||||
available_kv_cache_memory = (
|
||||
total_gpu_memory * self.cache_config.gpu_memory_utilization -
|
||||
peak_memory)
|
||||
memory_for_current_instance = total_gpu_memory * \
|
||||
self.cache_config.gpu_memory_utilization
|
||||
diff_profile = after_profile - before_profile
|
||||
msg = (f"Memory profiling takes {diff_profile.timestamp:.2f} seconds\n"
|
||||
"the current vLLM instance can use "
|
||||
"total_gpu_memory "
|
||||
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
|
||||
" x gpu_memory_utilization "
|
||||
f"({self.cache_config.gpu_memory_utilization:.2f})"
|
||||
f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
|
||||
"model weights take "
|
||||
f"{(self.model_runner.model_memory_usage / GiB_bytes):.2f}GiB;"
|
||||
" non_torch_memory takes "
|
||||
f"{(non_torch_allocations / GiB_bytes):.2f}GiB;"
|
||||
" PyTorch activation peak memory takes "
|
||||
f"{(diff_profile.torch_peak / GiB_bytes):.2f}GiB;"
|
||||
" the rest of the memory reserved for KV Cache is "
|
||||
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
|
||||
logger.info(msg)
|
||||
return int(available_kv_cache_memory)
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
raise NotImplementedError('SUPA do not support sleep mode')
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
# warm up sizes that are not in cudagraph capture sizes,
|
||||
# but users still want to compile for better performance,
|
||||
# e.g. for the max-num-batched token size in chunked prefill.
|
||||
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
|
||||
if not self.model_config.enforce_eager:
|
||||
warmup_sizes = [
|
||||
x for x in warmup_sizes
|
||||
if x not in self.scheduler_config.cuda_graph_sizes
|
||||
]
|
||||
for size in sorted(warmup_sizes, reverse=True):
|
||||
logger.info("Compile and warming up model for size %d", size)
|
||||
self.model_runner._dummy_run(size,
|
||||
skip_eplb=True,
|
||||
remove_lora=False)
|
||||
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
|
||||
# Warm up sampler and preallocate memory buffer for logits and other
|
||||
# sampling related tensors of max possible shape to avoid memory
|
||||
# fragmentation issue.
|
||||
# NOTE: This is called after `capture_model` on purpose to prevent
|
||||
# memory buffers from being cleared by `SUPAPlatform.empty_cache`.
|
||||
if get_pp_group().is_last_rank:
|
||||
max_num_reqs = min(
|
||||
self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens,
|
||||
)
|
||||
hidden_states, last_hidden_states = \
|
||||
self.model_runner._dummy_run(
|
||||
num_tokens=max_num_reqs,
|
||||
skip_eplb=True,
|
||||
)
|
||||
if self.model_runner.is_pooling_model:
|
||||
self.model_runner._dummy_pooler_run(hidden_states)
|
||||
else:
|
||||
self.model_runner._dummy_sampler_run(
|
||||
hidden_states=last_hidden_states)
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
return self.model_runner.get_supported_tasks()
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
if forward_pass and not get_pp_group().is_first_rank:
|
||||
# intermediate_tensors = IntermediateTensors(
|
||||
# get_pp_group().recv_tensor_dict(
|
||||
# all_gather_group=get_tp_group()))
|
||||
# use cpu send/recv
|
||||
if br_envs.VLLM_PP_CPU_SEND_RECV:
|
||||
cpu_dict = get_pp_group().recv_tensor_dict()
|
||||
gpu_dict = {
|
||||
k: v.to(torch.supa.current_device())
|
||||
for k, v in cpu_dict.items()
|
||||
}
|
||||
intermediate_tensors = IntermediateTensors(gpu_dict)
|
||||
else:
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
get_pp_group().recv_tensor_dict())
|
||||
|
||||
output = self.model_runner.execute_model(scheduler_output,
|
||||
intermediate_tensors)
|
||||
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
|
||||
return output
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
assert parallel_config.distributed_executor_backend != (
|
||||
"external_launcher") and not get_pp_group().is_last_rank
|
||||
# use cpu send/recv
|
||||
if br_envs.VLLM_PP_CPU_SEND_RECV:
|
||||
cpu_dict = {k: v.cpu() for k, v in output.tensors.items()}
|
||||
get_pp_group().send_tensor_dict(cpu_dict)
|
||||
else:
|
||||
get_pp_group().send_tensor_dict(output.tensors)
|
||||
kv_connector_output = output.kv_connector_output
|
||||
if not kv_connector_output:
|
||||
return None
|
||||
# In case of PP with kv transfer, we need to pass through the
|
||||
# kv_connector_output
|
||||
if (not kv_connector_output.finished_sending
|
||||
and not kv_connector_output.finished_recving):
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
output.kv_connector_output = kv_connector_output
|
||||
return output
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||
return self.model_runner.take_draft_token_ids()
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
if is_start:
|
||||
self.profiler.start()
|
||||
else:
|
||||
self.profiler.stop()
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.model_runner._dummy_run(1)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
return self.model_runner.list_loras()
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def check_health(self) -> None:
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
from vllm.model_executor.model_loader.loader import ShardedStateLoader
|
||||
|
||||
ShardedStateLoader.save_model(
|
||||
self.model_runner.model,
|
||||
path,
|
||||
pattern=pattern,
|
||||
max_size=max_size,
|
||||
)
|
||||
|
||||
def _init_worker_distributed_environment(self) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
set_custom_all_reduce(
|
||||
not self.parallel_config.disable_custom_all_reduce)
|
||||
init_distributed_environment(self.parallel_config.world_size,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
"sccl",
|
||||
timeout=datetime.timedelta(seconds=100))
|
||||
ensure_model_parallel_initialized(
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(self.vllm_config)
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# Check if the GPU supports the dtype.
|
||||
# TODO: add checkers
|
||||
return
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
capability = SUPAPlatform.get_device_capability()
|
||||
gpu_name = SUPAPlatform.get_device_name()
|
||||
|
||||
if capability is None:
|
||||
compute_str = "does not have a compute capability"
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f"has compute capability {version_str}"
|
||||
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs with compute capability "
|
||||
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
Reference in New Issue
Block a user