init

vllm/forward_context.py (new file, 402 lines)
@@ -0,0 +1,402 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

logger = init_logger(__name__)

track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
last_logging_time: float = 0
forward_start_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list)

class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the number of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """
    num_tokens: int
    uniform_decode: bool = False
    """
    False can also be used for a uniform decode batch to dispatch to the
    cudagraph supporting non-uniform batches.
    """

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """
        Return a non-uniform version of the current batch descriptor.
        """
        return BatchDescriptor(self.num_tokens, uniform_decode=False)

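# Illustrative example (not part of the module; the values are made up): a
# cudagraph dispatch key for a padded uniform-decode batch and its
# non-uniform fallback key.
#
#     desc = BatchDescriptor(num_tokens=256, uniform_decode=True)
#     desc.non_uniform  # BatchDescriptor(num_tokens=256, uniform_decode=False)
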
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
                           sequence_parallel_size: int) -> list[int]:
    sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
                 sequence_parallel_size)

    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
    return sp_tokens.tolist()

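# Worked example (illustrative values): with two DP ranks holding [7, 12]
# tokens and sequence_parallel_size=4, the ceil-division gives [2, 3] tokens
# per SP rank, repeated across the 4 SP ranks of each DP rank:
#
#     _compute_sp_num_tokens(torch.tensor([7, 12]), 4)
#     # -> [2, 2, 2, 2, 3, 3, 3, 3]
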
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
                                      sequence_parallel_size: int,
                                      max_num_tokens: int,
                                      chunk_idx: int) -> list[int]:

    sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
                                       sequence_parallel_size)
    sp_size = len(sp_tokens)

    local_size = [-1] * sp_size
    for i in range(sp_size):
        # Take into account sharding if MoE activation is sequence parallel.
        local_size[i] = min(max_num_tokens,
                            sp_tokens[i] - (max_num_tokens * chunk_idx))
        if local_size[i] <= 0:
            local_size[i] = 1  # ensure lockstep even if done
    return local_size

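# Worked example (illustrative values): DP token counts [7, 12],
# sequence_parallel_size=1, max_num_tokens=5. The per-rank sizes per chunk:
#
#     chunk_idx=0 -> [5, 5]
#     chunk_idx=1 -> [2, 5]
#     chunk_idx=2 -> [1, 2]   # rank 0 is already done but still runs 1 token
#                             # so that both ranks stay in lockstep
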
@dataclass
class DPMetadata:
    max_tokens_across_dp_cpu: torch.Tensor
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes context manager
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Gather the num_tokens across all DP ranks and return results in a
        CPU tensor of size dp_size.
        """
        from vllm.distributed.parallel_state import get_dp_group
        device = current_platform.device_type
        group = get_dp_group().device_group

        # Transferring this tensor from GPU to CPU will introduce a GPU sync
        # point that could adversely affect performance of vllm with async
        # scheduling. This environment variable exists to quickly disable
        # this optimization if we run into this case.
        if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
            logger.info_once(
                "Using CPU all reduce to synchronize DP padding between ranks.")
            device = "cpu"
            group = get_dp_group().cpu_group
        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = num_tokens
        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                         device=device,
                                         dtype=torch.int32)
        dist.all_reduce(num_tokens_tensor, group=group)
        return num_tokens_tensor.cpu()

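    # Illustrative example: with dp_size=4, each rank fills only its own slot
    # of the zero-initialized tensor, and the summing all_reduce reconstructs
    # every rank's count, e.g. ranks contributing 8, 16, 8 and 32 tokens all
    # end up with tensor([8, 16, 8, 32], dtype=torch.int32) on the CPU.
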
    # Get the cumulative tokens across sequence parallel ranks.
    # In this case the input to the MoEs will be distributed w.r.t both
    # DP and TP rank.
    # When sp_size==1, this is just the cumulative num tokens across DP.
    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        num_tokens_across_sp_cpu = (
            (self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
        num_tokens_across_sp_cpu = (
            num_tokens_across_sp_cpu.repeat_interleave(sp_size))
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

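    # Worked example (illustrative values): num_tokens_across_dp_cpu=[7, 12]
    # and sp_size=2 give per-SP-rank counts [4, 4, 6, 6], so
    #
    #     self.cu_tokens_across_sp(2)  # -> tensor([ 4,  8, 14, 20])
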
    @staticmethod
    def should_ubatch_across_dp(
            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
            padded_num_tokens_per_ubatch: int, dp_size: int,
            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
        """
        1. Decides if each DP rank is going to microbatch. Either all ranks
        run with microbatching or none of them do. If this function decides
        not to run with microbatching, it will "abort", meaning that no
        padding information is returned to the caller, i.e. it returns
        (False, None).

        2. Determines the total number of tokens that each rank will run.
        All ranks will be padded out so that they run with the same number
        of tokens.

        Returns: tuple[
            should_ubatch: Are all DP ranks going to microbatch?
            num_tokens_after_padding: A tensor containing the total number of
            tokens per-microbatch for each DP rank including padding. Will be
            None if should_ubatch is False.
        ]
        """

        device = current_platform.device_type
        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
        tensor[2][dp_rank] = 1 if should_ubatch else 0

        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(tensor, group=get_dp_group().device_group)

        result: bool = bool(torch.all(tensor[2] == 1).item())
        if not result:
            return result, None

        orig_num_tokens_tensor = tensor[0, :]
        padded_num_tokens_tensor = tensor[1, :]

        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
            logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
                         padded_max_num_tokens)
            return False, None
        return result, padded_num_tokens_tensor.cpu()

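    # Illustrative example: with dp_size=2, if rank 0 calls this with
    # should_ubatch=True but rank 1 passes should_ubatch=False, the all_reduce
    # leaves tensor[2] == [1, 0] on both ranks, so every rank returns
    # (False, None) and no rank microbatches.
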
    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        attn_metadata: Any,
        num_tokens: int,
        num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
    ) -> "DPMetadata":

        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        if attn_metadata is not None and hasattr(attn_metadata,
                                                 "num_prefill_tokens"):
            # for v0 attention backends
            batchsize = attn_metadata.num_prefill_tokens + \
                attn_metadata.num_decode_tokens
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
        assert (num_tokens_across_dp_cpu is None
                or num_tokens_across_dp_cpu[dp_rank] == batchsize
                ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
        if num_tokens_across_dp_cpu is None:
            num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

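    # Illustrative example: with data_parallel_size=2, this rank running 96
    # tokens and the other rank 160 (and no precomputed
    # num_tokens_across_dp_cpu), make() all-reduces the per-rank counts and
    # returns a DPMetadata holding
    #     num_tokens_across_dp_cpu = tensor([ 96, 160], dtype=torch.int32)
    #     max_tokens_across_dp_cpu = tensor(160, dtype=torch.int32)
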
    @contextmanager
    def chunked_sizes(self, sequence_parallel_size: int,
                      max_chunk_size_per_rank: int, chunk_idx: int):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                    we use SP between the layers to avoid
                                    redundant ops. We need this value to
                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        self.local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size,
            max_chunk_size_per_rank, chunk_idx)
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

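    # Usage sketch (illustrative; `dp_metadata`, `sp_size` and `num_chunks`
    # are hypothetical names):
    #
    #     for chunk_idx in range(num_chunks):
    #         with dp_metadata.chunked_sizes(sp_size, 512, chunk_idx) as sizes:
    #             ...  # sizes == dp_metadata.get_chunk_sizes_across_dp_rank()
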
    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as
        self.chunked_sizes but without any chunking.
        """
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size)
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
        assert self.local_sizes is not None
        return self.local_sizes

@dataclass
class ForwardContext:
    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    """
    Type AttentionMetadata for v0.
    Type Dict[str, AttentionMetadata] for v1, mapping from layer_name of each
    attention layer to its attention metadata.
    Type List[Dict[str, AttentionMetadata]] for DBO, a list of size two, one
    dict per microbatch.
    Set dynamically for each forward pass.
    """
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
                         list[dict[str, "AttentionMetadata"]]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE;
    # by default NONE, i.e. no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None

    ubatch_slices: Optional[UBatchSlices] = None

    def __post_init__(self):
        assert self.cudagraph_runtime_mode in [
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"

_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context."""
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context

def create_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        dp_metadata: Optional[DPMetadata] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    return ForwardContext(no_compile_layers=vllm_config.compilation_config.
                          static_forward_context,
                          virtual_engine=virtual_engine,
                          attn_metadata=attn_metadata,
                          dp_metadata=dp_metadata,
                          cudagraph_runtime_mode=cudagraph_runtime_mode,
                          batch_descriptor=batch_descriptor,
                          ubatch_slices=ubatch_slices)

@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
    """A context manager that overrides the current forward context.
    This is used to override the forward context for a specific
    forward pass.
    """
    global _forward_context
    prev_context = _forward_context
    _forward_context = forward_context
    try:
        yield
    finally:
        _forward_context = prev_context

@contextmanager
def set_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """A context manager that stores the current forward context,
    which can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    forward_context = create_forward_context(attn_metadata, vllm_config,
                                             virtual_engine, dp_metadata,
                                             cudagraph_runtime_mode,
                                             batch_descriptor, ubatch_slices)

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        global last_logging_time, batchsize_logging_interval
        if need_to_track_batchsize:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = attn_metadata.num_prefill_tokens + \
                    attn_metadata.num_decode_tokens
            else:
                # for v1 attention backends
                batchsize = num_tokens
            # we use synchronous scheduling right now,
            # adding a sync point here should not affect
            # scheduling of the next batch
            from vllm.platforms import current_platform
            synchronize = current_platform.synchronize
            if synchronize is not None:
                synchronize()
            now = time.perf_counter()
            # time measurement is in milliseconds
            batchsize_forward_time[batchsize].append(
                (now - forward_start_time) * 1000)
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                forward_stats = []
                for bs, times in batchsize_forward_time.items():
                    if len(times) <= 1:
                        # can be cudagraph / profiling run
                        continue
                    median = torch.quantile(torch.tensor(times), q=0.5).item()
                    median = round(median, 2)
                    forward_stats.append((bs, len(times), median))
                forward_stats.sort(key=lambda x: x[1], reverse=True)
                if forward_stats:
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"),
                                forward_stats)

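# Usage sketch (illustrative; `model`, `input_ids` and `n` are hypothetical
# names): the model runner sets the context around a forward pass, and layers
# retrieve it via get_forward_context().
#
#     with set_forward_context(attn_metadata, vllm_config, num_tokens=n):
#         hidden_states = model(input_ids)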