Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
94
vllm/model_executor/offloader/prefetch_ops.py
Normal file
94
vllm/model_executor/offloader/prefetch_ops.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Custom ops for prefetch offloader torch.compile + CUDA graph compatibility.
|
||||
|
||||
These ops use mutates_args to create data dependencies that prevent
|
||||
the compiler from reordering prefetch/sync operations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
# --- wait_prefetch op ---
|
||||
|
||||
|
||||
def _wait_prefetch_impl(
    input_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Block until the prefetch of layer ``layer_idx`` has finished.

    Bridges the compute stream with the copy stream so the prefetched
    weights are guaranteed to be resident before the layer runs.

    Args:
        input_tensor: The layer's input (e.g. hidden_states). It is never
            touched here; declaring it as mutated merely gives torch.compile
            a data dependency that pins this op's position in the graph.
        layer_idx: Index of the layer whose prefetch must complete.
    """
    offloader = get_offloader()
    offloader._wait_for_layer(layer_idx)
|
||||
|
||||
|
||||
def _wait_prefetch_fake(
|
||||
input_tensor: torch.Tensor,
|
||||
layer_idx: int,
|
||||
) -> None:
|
||||
"""Fake implementation for torch.compile tracing."""
|
||||
return
|
||||
|
||||
|
||||
# --- start_prefetch op ---
|
||||
|
||||
|
||||
def _start_prefetch_impl(
    output_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Kick off the async H2D prefetch of layer ``layer_idx``'s weights.

    The copy is issued on the offloader's copy stream and overlaps with
    ongoing compute.

    Args:
        output_tensor: Output of the just-finished forward. It is never
            touched here; declaring it as mutated stops torch.compile from
            hoisting this op above the computation that produced it.
        layer_idx: Index of the layer to start prefetching.
    """
    offloader = get_offloader()
    offloader._start_prefetch(layer_idx)
|
||||
|
||||
|
||||
def _start_prefetch_fake(
|
||||
output_tensor: torch.Tensor,
|
||||
layer_idx: int,
|
||||
) -> None:
|
||||
"""Fake implementation for torch.compile tracing."""
|
||||
return
|
||||
|
||||
|
||||
def register_prefetch_offloader_ops() -> None:
    """Register the prefetch offloader custom ops.

    Must be called before the ops are used; this module invokes it once at
    import time.
    """
    # (op name, eager impl, fake impl for tracing, arg declared as mutated).
    # The "mutated" arg is the fence that keeps torch.compile from
    # reordering the op relative to the tensor's producer/consumers.
    op_specs = (
        ("wait_prefetch", _wait_prefetch_impl, _wait_prefetch_fake,
         "input_tensor"),
        ("start_prefetch", _start_prefetch_impl, _start_prefetch_fake,
         "output_tensor"),
    )
    for op_name, impl, fake, mutated_arg in op_specs:
        direct_register_custom_op(
            op_name=op_name,
            op_func=impl,
            mutates_args=[mutated_arg],
            fake_impl=fake,
        )
|
||||
|
||||
|
||||
# Register the custom ops as a side effect of importing this module, so
# they are available to any code that imports it.
register_prefetch_offloader_ops()
|
||||
Reference in New Issue
Block a user