first commit

This commit is contained in:
2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

19
vllm_br/v1/__init__.py Normal file
View File

@@ -0,0 +1,19 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import attention # noqa: F401
from . import executor # noqa: F401
from . import core, engine, kv_cache_interface, outputs, sample # noqa: F401

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import backends # noqa: F401

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import mla # noqa: F401
from .utils import *

View File

@@ -0,0 +1,657 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple
import torch
import torch_br
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType,
is_quantized_kv_cache)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.config import VllmConfig
from vllm.logger import logger
from vllm.v1.attention.backends.flash_attn import _get_sliding_window_configs
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_br.config.compilation import SUPAGraphMode
if TYPE_CHECKING:
pass
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
class SUPAFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the SUPA flash-attention kernels.

    Ties together the implementation, metadata and metadata-builder classes
    defined in this module and describes the KV-cache shapes they use.
    """

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False
    supports_quant_query_input: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this backend."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes accepted by this backend."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Check ``head_size`` against ``get_supported_head_sizes()``.

        Raises:
            ValueError: if ``head_size`` is not in the supported list.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "SUPAFLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["SUPAFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return SUPAFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["SUPAFlashAttentionMetadata"]:
        """Return the per-step attention metadata class."""
        return SUPAFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["SUPAFlashAttentionMetadataBuilder"]:
        """Return the class that builds the attention metadata."""
        return SUPAFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        The leading dimension of 2 presumably separates keys from values
        (TODO confirm against the kernels).

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the regrouped 4-D KV-cache shape used for SUPA speed-up.

        ``th_gran`` consecutive blocks (see
        ``get_kv_cache_usharp_alignment``) are merged along the token
        dimension and the head dimensions are flattened, giving
        ``(2, n_block, th_gran * block_size, num_kv_heads * head_size)``
        where ``n_block`` is ``ceil(num_blocks / th_gran)``, at least 1.
        """
        th_gran = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
            block_size)
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # Lazy %-formatting: the message is only rendered when DEBUG logging
        # is enabled. The original text was also missing the closing "]" of
        # the first shape.
        logger.debug(
            "Origin kv cache shape is [2, %s, %s, %s, %s]. "
            "For SUPA Speed up, use [2, %s, %s, %s]",
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            n_block,
            th_gran * block_size,
            num_kv_heads * head_size,
        )
        return (2, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks are grouped together: ``2048 // block_size``.

        NOTE(review): 2048 looks like a hardware limit on the merged token
        dimension -- confirm. The result is 0 when ``block_size > 2048``.
        """
        max_h_limit = 2048
        return max_h_limit // block_size

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the permutation from ``get_kv_cache_shape`` to memory layout.

        Raises:
            ValueError: if ``get_kv_cache_layout()`` is neither "NHD" nor
                "HND".
        """
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 KV-cache dtype string to its torch dtype.

        Raises:
            ValueError: if ``kv_cache_dtype`` is not "fp8" or "fp8_e4m3".
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass
class SUPAFlashAttentionMetadata:
    # Per-step attention metadata produced by
    # SUPAFlashAttentionMetadataBuilder.build() and read by
    # SUPAFlashAttentionImpl.forward().
    #
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
    # BIREN Attention Params
    seq_start_loc: torch.Tensor
    context_lens: torch.Tensor
    max_decode_seq_len: int
    do_cache: bool  # when use attentionsplit, do cache = False
    num_actual_reqs: torch.Tensor
    # Graph mode
    supagraph_runtime_mode: SUPAGraphMode
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int
    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]
    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0
    causal: bool = True
    # Local attention is currently disabled; the definitions below are kept
    # commented out for reference.
    # @dataclass
    # class LocalAttentionMetadata:
    #     local_query_start_loc: torch.Tensor
    #     local_seqused_k: torch.Tensor
    #     local_block_table: torch.Tensor
    #     local_max_query_len: int
    #     local_max_seq_len: int
    #     local_scheduler_metadata: Optional[torch.Tensor]
    # local_attn_metadata: Optional[LocalAttentionMetadata] = None
class SUPAFlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[SUPAFlashAttentionMetadata]):
    """Builds ``SUPAFlashAttentionMetadata`` from the common attention metadata.

    Ahead-of-time (AOT) scheduling is always disabled here: the nested
    ``schedule`` helper in ``build()`` raises ``NotImplementedError`` if it is
    ever switched on.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.ALWAYS
    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache the config values needed by ``build()``.

        Args:
            kv_cache_spec: KV-cache spec; its ``dtype`` and ``block_size``
                are read here.
            layer_names: forwarded unchanged to the base class.
            vllm_config: source of the model, parallel, cache and
                compilation configs.
            device: forwarded to the base class.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        supports_spec_as_decode = True
        self._init_reorder_batch_threshold(1, supports_spec_as_decode)
        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling is not supported by the SUPA attention path.
        self.aot_schedule = False
        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        self.max_cudagraph_size = self.compilation_config.max_capture_size
        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> SUPAFlashAttentionMetadata:
        """Assemble the attention metadata for one scheduling step.

        Args:
            common_prefix_len: length of the prefix shared by all requests;
                a value > 0 enables the cascade fields.
            common_attn_metadata: backend-agnostic metadata. Besides the
                common fields this method also reads ``num_actual_reqs``,
                ``seq_start_loc``, ``context_lens``, ``max_decode_seq_len``
                and ``supagraph_runtime_mode`` from it.
            fast_build: disables AOT scheduling, used when there will be few
                iterations i.e. spec-decode.

        Returns:
            A populated ``SUPAFlashAttentionMetadata`` with ``do_cache=True``.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True)
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal
        num_actual_reqs = common_attn_metadata.num_actual_reqs
        seq_start_loc = common_attn_metadata.seq_start_loc
        context_lens = common_attn_metadata.context_lens

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build
        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call so the layers are constructed (cannot populate
            # in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
                    aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if self.use_full_cuda_graph and \
                num_actual_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # Placeholder for the AOT scheduler: always yields None unless
            # AOT scheduling was (incorrectly) enabled, which is an error.
            if self.aot_schedule:
                raise NotImplementedError(
                    'aot schedule not support in SUPA attention')
            return None

        use_cascade = common_prefix_len > 0
        if use_cascade:
            # Use self.device (set by the base class and already used for
            # suffix_kv_lens below); this builder has no `runner` attribute.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)

        # Fall back to the longest sequence in the batch when the caller did
        # not provide a decode-specific maximum.
        max_decode_seq_len = common_attn_metadata.max_decode_seq_len
        if max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())

        attn_metadata = SUPAFlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            do_cache=True,
            num_actual_reqs=num_actual_reqs,
            supagraph_runtime_mode=common_attn_metadata.supagraph_runtime_mode)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Always return False: cascade attention is not used by this backend."""
        return False
class SUPAFlashAttentionImpl(AttentionImpl):
    """Attention implementation that dispatches to the ``torch_br`` kernels."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Validate and store the layer's attention configuration.

        Raises:
            ValueError: unsupported head size, or ``sinks`` with the wrong
                leading dimension or a dtype other than float32.
            NotImplementedError: ``attn_type`` is neither DECODER nor
                ENCODER_ONLY, or an fp8 KV cache is requested but
                ``flash_attn_supports_fp8()`` is false.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        # ALiBi slopes, when given, are kept as a float32 CPU tensor.
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32, device="cpu"))
        # A falsy window (None or 0) is normalised to None.
        self.sliding_window = sliding_window or None
        self.kv_cache_dtype = kv_cache_dtype
        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        SUPAFlashAttentionBackend.validate_head_size(head_size)

        self.attn_type = attn_type
        supported_attn_types = (AttentionType.DECODER,
                                AttentionType.ENCODER_ONLY)
        if attn_type not in supported_attn_types:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")

        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype):
            if not flash_attn_supports_fp8():
                raise NotImplementedError(
                    "FlashAttention does not support fp8 kv-cache on this device."
                )

        if sinks is not None:
            head_count = sinks.shape[0]
            if head_count != num_heads:
                raise ValueError(
                    "Sinks must have the same number of heads as the number of "
                    f"heads in the layer. Expected {num_heads}, but got "
                    f"{sinks.shape[0]}.")
            if sinks.dtype != torch.float32:
                raise ValueError("Sinks must be of type float32, but got "
                                 f"{sinks.dtype}.")
        self.sinks: Optional[torch.Tensor] = sinks

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention for one layer and return the kernel output.

        Args:
            layer: the owning attention layer (not read here).
            query, key, value: projected activations for the current step.
                NOTE(review): the SUPA kernels are said to expect
                [batch_size, num_tokens, num_heads * head_size]; the encoder
                path asserts that ``query`` is 3-D -- confirm for other paths.
            kv_cache: paged KV cache, or None.
            attn_metadata: metadata for this step; None during a profiling
                run.
            output: must be None; preallocated outputs are not supported.

        Returns:
            The tensor produced by the selected kernel, or ``query`` itself
            when ``attn_metadata`` is None.
        """
        assert output is None, "Output tensor should not provided."
        if attn_metadata is None:
            # Profiling run.
            # FIXME: this may lead to a wrong block estimation.
            return query

        is_encoder = self.attn_type in (AttentionType.ENCODER_ONLY,
                                        AttentionType.ENCODER)

        # Write the new K/V into the paged cache (skipped for encoder
        # attention and when the metadata asks not to cache).
        if kv_cache is not None and attn_metadata.do_cache and not is_encoder:
            torch_br.supa_kvcache_store_infer_v2(
                kv_cache,
                key,
                value,  # type: ignore
                attn_metadata.slot_mapping,
                self.head_size)

        if self.sinks is not None:
            return self.forward_sw_sinks(query, kv_cache, attn_metadata)

        if is_encoder:
            # Encoder attention works directly on q/k/v without the cache.
            assert len(query.shape) == 3
            return torch_br.supa_flash_attention_infer(  # type: ignore
                query,
                key,
                value,
                attn_metadata.query_start_loc,
                self.head_size,
                len(attn_metadata.query_start_loc),  # type: ignore
                self.alibi_slopes,
                softmax_scale=self.scale,
                is_causal=_get_causal_option(self.attn_type))

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        graph_mode = attn_metadata.supagraph_runtime_mode
        eager_or_decode_graph = graph_mode is None or graph_mode in (
            SUPAGraphMode.NONE, SUPAGraphMode.FULL_DECODE_ONLY)

        # The decode kernel is only used for pure-decode batches (non-mtp)
        # in eager / decode-only graph mode; everything else goes through
        # the paged prefill kernel.
        if not eager_or_decode_graph or num_prefill_tokens > 0:
            return self._paged_prefill_attention(query, kv_cache,
                                                 attn_metadata)

        return torch_br.supa_attention_decoder_infer_v2(  # type: ignore
            query,  # type: ignore
            kv_cache,
            attn_metadata.block_table,
            attn_metadata.seq_lens,
            attn_metadata.max_decode_seq_len,
            self.head_size,
            attn_metadata.num_prefills,
            self.alibi_slopes,
            softmax_scale=self.scale)

    def _paged_prefill_attention(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
    ) -> torch.Tensor:
        """Call the flash-attention kernel that reads K/V from the paged cache."""
        return torch_br.br_flash_attn_with_kvcache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            self.head_size,
            alibi_slopes=self.alibi_slopes,
            softmax_scale=self.scale,
            num_reqs=attn_metadata.num_actual_reqs)

    # sliding window with sinks impl
    def forward_sw_sinks(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
    ) -> torch.Tensor:
        """Run the sliding-window attention kernel that takes attention sinks."""
        # prefix-enabled attention
        return torch_br.supa_flash_attn_cache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            attn_metadata.context_lens,
            attn_metadata.slot_mapping,
            attn_metadata.max_seq_len,
            self.head_size,
            window_size=self.sliding_window,
            sinks=self.sinks)
def _get_causal_option(attn_type: str) -> bool:
    """Tell whether attention of the given type should be causal.

    Args:
        attn_type (AttentionType): The type of attention being evaluated

    Returns:
        bool: `False` for encoder, encoder-only and encoder-decoder
        attention; `True` for every other type.
    """
    non_causal_types = (
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type not in non_causal_types

View File

@@ -0,0 +1,19 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import flashmla # noqa: F401
from . import flashmla_sparse # noqa: F401
from . import indexer # noqa: F401

View File

@@ -0,0 +1,657 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import itertools
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
import torch
import torch_br
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
is_quantized_kv_cache)
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, ReplicatedLinear,
RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.v1.attention.backends.flash_attn import _get_sliding_window_configs
from vllm.v1.attention.backends.mla.common import (MLACommonImpl,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.mla.flashmla import (FlashMLABackend,
FlashMLAMetadata)
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_br import envs
from vllm_br.model_executor.layers.br_utils import _convert_to_numa_tensor
from vllm_br.utils import get_grandparent_pid
from vllm_br.v1.attention.backends.utils import SUPACommonAttentionMetadata
if TYPE_CHECKING:
pass
logger = init_logger(__name__)
class SupaFlashMLABackend(FlashMLABackend):
    """FlashMLA backend descriptor specialised for SUPA devices.

    Overrides the name, the metadata/builder/impl classes and the KV-cache
    shape helpers of ``FlashMLABackend``.
    """

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the attention head sizes accepted by this backend."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "SUPAFLASHMLA"

    @staticmethod
    def get_metadata_cls() -> type["SupaFlashMLAMetadata"]:
        """Return the per-step attention metadata class."""
        return SupaFlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["SupaFlashMLAMetadataBuilder"]:
        """Return the class that builds the attention metadata."""
        return SupaFlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["SupaFlashMLAImpl"]:
        """Return the attention implementation class."""
        return SupaFlashMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the regrouped 4-D KV-cache shape used for SUPA speed-up.

        ``th_gran`` consecutive blocks (see
        ``get_kv_cache_usharp_alignment``) are merged along the token
        dimension and the head dimensions are flattened, giving
        ``(1, n_block, th_gran * block_size, num_kv_heads * head_size)``
        where ``n_block`` is ``ceil(num_blocks / th_gran)``, at least 1.
        The leading dimension is 1 rather than 2 because K and V share one
        cache here.
        """
        th_gran = SupaFlashMLABackend.get_kv_cache_usharp_alignment(block_size)
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # Lazy %-formatting: the message is only rendered when DEBUG logging
        # is enabled. The original text was also missing the closing "]" of
        # the first shape.
        logger.debug(
            "Origin kv cache shape is [1, %s, %s, %s, %s]. "
            "For SUPA Speed up, use [1, %s, %s, %s]",
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            n_block,
            th_gran * block_size,
            num_kv_heads * head_size,
        )
        # TODO, shared kv only used in deepseek
        return (1, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks are grouped together: ``2048 // block_size``.

        NOTE(review): 2048 looks like a hardware limit on the merged token
        dimension -- confirm. The result is 0 when ``block_size > 2048``.
        """
        max_h_limit = 2048
        return max_h_limit // block_size
@dataclass
class SupaFlashMLAMetadata:
    # Per-step attention metadata returned by
    # SupaFlashMLABackend.get_metadata_cls(). It is a standalone dataclass;
    # the FlashMLAMetadata base class below is intentionally commented out.
    # class SupaFlashMLAMetadata(FlashMLAMetadata):
    #
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
    # BIREN Attention Params
    seq_start_loc: torch.Tensor
    context_lens: torch.Tensor
    max_decode_seq_len: int
    do_cache: bool  # when use attentionsplit, do cache = False
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int
    num_actual_reqs: torch.Tensor
    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]
    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
class SupaFlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: int = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
FlashMLAMetadata)
self.vllm_config = vllm_config
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros(
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
# TileSchedulerMetaDataSize = 8
(num_sms, 8),
device=self.device,
dtype=torch.int32,
)
self.cg_buf_num_splits = torch.empty(
(vllm_config.scheduler_config.max_num_seqs + 1),
device=self.device,
dtype=torch.int32)
self.aot_schedule = False
logger.warning(
"AOT Schedule is disabled when using SUPAFlashAttention.")
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None
supports_spec_as_decode = True
self._init_reorder_batch_threshold(1, supports_spec_as_decode)
def build(self,
          common_prefix_len: int,
          common_attn_metadata: SUPACommonAttentionMetadata,
          fast_build: bool = False):
    """Assemble the ``SupaFlashMLAMetadata`` for one scheduled batch.

    Args:
        common_prefix_len: Length of the KV prefix shared by all requests.
            A value greater than zero fills in the cascade-attention
            fields.
        common_attn_metadata: Batch description (query/sequence lengths,
            block table, slot mapping, ...) plus the optional SUPA extras
            ``seq_start_loc``, ``context_lens`` and ``max_decode_seq_len``.
            Extras that are ``None`` are derived here.
        fast_build: When true, the AOT scheduling path is not considered.

    Returns:
        The populated ``SupaFlashMLAMetadata``.

    Raises:
        NotImplementedError: If ``self.aot_schedule`` is enabled; AOT
            scheduling is not implemented for SUPA attention.
    """
    num_reqs = common_attn_metadata.num_reqs
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    max_query_len = common_attn_metadata.max_query_len
    seq_lens_cpu = common_attn_metadata.seq_lens_cpu
    max_seq_len = int(seq_lens_cpu[:num_reqs].max())
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens
    block_table_tensor = common_attn_metadata.block_table_tensor
    slot_mapping = common_attn_metadata.slot_mapping
    num_actual_reqs = common_attn_metadata.num_actual_reqs

    aot_schedule = self.aot_schedule and not fast_build

    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
        split_decodes_and_prefills(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold,
            require_uniform=True)

    if self.aot_sliding_window is None:
        self.aot_sliding_window = (-1, -1)
        # The AOT scheduler needs one sliding-window value shared by all
        # layers. It can only be collected on the first build() call,
        # once the layers exist (not in __init__).
        if aot_schedule:
            sliding_window_configs = _get_sliding_window_configs(
                self.vllm_config)
            if len(sliding_window_configs) == 1:
                sliding_window_config = sliding_window_configs.pop()
                if sliding_window_config is not None:
                    self.aot_sliding_window = sliding_window_config
            elif len(sliding_window_configs) > 1:
                self.aot_schedule = False
                aot_schedule = False

    def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 max_seq_len, causal):
        # Placeholder for the AOT scheduler: SUPA has none, so this either
        # rejects the request or yields no scheduler metadata.
        if self.aot_schedule:
            raise NotImplementedError(
                'aot schedule not support in SUPA attention')
        return None

    use_cascade = common_prefix_len > 0
    if use_cascade:
        # Use the builder's own device and the CPU sequence lengths from
        # the common metadata; this builder stores no ``runner`` object,
        # so the former ``self.runner.*`` lookups could not resolve.
        cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                            dtype=torch.int32,
                                            device=self.device)
        prefix_kv_lens = torch.tensor([common_prefix_len],
                                      dtype=torch.int32,
                                      device=self.device)
        suffix_kv_lens = (seq_lens_cpu[:num_reqs] -
                          common_prefix_len).to(self.device)
        prefix_scheduler_metadata = schedule(
            batch_size=1,
            cu_query_lens=cu_prefix_query_lens,
            max_query_len=num_actual_tokens,
            seqlens=prefix_kv_lens,
            max_seq_len=common_prefix_len,
            causal=False)
        scheduler_metadata = schedule(batch_size=num_reqs,
                                      cu_query_lens=query_start_loc,
                                      max_query_len=max_query_len,
                                      seqlens=suffix_kv_lens,
                                      max_seq_len=max_seq_len -
                                      common_prefix_len,
                                      causal=True)
    else:
        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        prefix_scheduler_metadata = None
        scheduler_metadata = schedule(batch_size=num_reqs,
                                      cu_query_lens=query_start_loc,
                                      max_query_len=max_query_len,
                                      seqlens=seq_lens,
                                      max_seq_len=max_seq_len,
                                      causal=True)

    # --- Biren attention params: derive whatever the caller left unset ---
    if common_attn_metadata.seq_start_loc is None:
        # Exclusive prefix sum of the sequence lengths. Batches with more
        # than 8 requests are copied to the host in one transfer first;
        # smaller ones are scanned in place (threshold kept from the
        # original code, presumably a heuristic).
        lens_for_scan = seq_lens.cpu() if len(seq_lens) > 8 else seq_lens
        seq_start_loc = torch.tensor(
            [0] + list(itertools.accumulate(lens_for_scan)),
            device=query_start_loc.device,
            dtype=torch.int32)
    else:
        seq_start_loc = common_attn_metadata.seq_start_loc

    if common_attn_metadata.context_lens is None:
        # Context length = total sequence length minus this step's
        # query length.
        context_lens = seq_lens - (query_start_loc[1:] -
                                   query_start_loc[:-1])
    else:
        context_lens = common_attn_metadata.context_lens

    if common_attn_metadata.max_decode_seq_len is None:
        max_decode_seq_len = int(seq_lens.max().item())
    else:
        max_decode_seq_len = common_attn_metadata.max_decode_seq_len

    attn_metadata = SupaFlashMLAMetadata(
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        query_start_loc=query_start_loc,
        max_seq_len=max_seq_len,
        seq_lens=seq_lens,
        block_table=block_table_tensor,
        slot_mapping=slot_mapping,
        use_cascade=use_cascade,
        common_prefix_len=common_prefix_len,
        scheduler_metadata=scheduler_metadata,
        cu_prefix_query_lens=cu_prefix_query_lens,
        prefix_kv_lens=prefix_kv_lens,
        suffix_kv_lens=suffix_kv_lens,
        prefix_scheduler_metadata=prefix_scheduler_metadata,
        # Biren Attention Params
        seq_start_loc=seq_start_loc,
        context_lens=context_lens,
        max_decode_seq_len=max_decode_seq_len,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        do_cache=True,
        num_actual_reqs=num_actual_reqs)
    return attn_metadata
def can_run_in_cudagraph(
        self, common_attn_metadata: SUPACommonAttentionMetadata) -> bool:
    """Report whether this batch may run inside a captured graph.

    Always returns ``False``; ``common_attn_metadata`` is not inspected.
    """
    return False
def use_cascade_attention(self, *args, **kwargs) -> bool:
    """Return ``False`` unconditionally; all arguments are ignored."""
    return False
# class SupaFlashMLAImpl(FlashMLAImpl):
class SupaFlashMLAImpl(MLACommonImpl[SupaFlashMLAMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA-specific arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        rotary_emb: RotaryEmbedding,
        # q_proj is expected to be q_b_proj when q_lora_rank is set; the
        # calling layer is responsible for passing the right matrix.
        q_proj: ColumnParallelLinear,  # q_b_proj
        o_proj: RowParallelLinear,
        kv_a_proj_with_mqa: ReplicatedLinear,
        kv_a_layernorm: Any,
        q_a_proj: ReplicatedLinear,
        q_a_layernorm: Any,
        **mla_args) -> None:
    """Store the MLA sub-modules and validate the configuration.

    Raises:
        NotImplementedError: If ALiBi slopes, a sliding window or a logits
            soft cap is requested, if ``attn_type`` is not the decoder
            type, or if the KV cache dtype is quantized.
    """
    super().__init__(num_heads, head_size, scale, num_kv_heads,
                     alibi_slopes, sliding_window, kv_cache_dtype,
                     logits_soft_cap, attn_type,
                     kv_sharing_target_layer_name, q_lora_rank,
                     kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim,
                     qk_head_dim, v_head_dim, kv_b_proj, **mla_args)

    # Projection / normalisation modules consumed by
    # process_weights_after_loading() and forward().
    self.rotary_emb = rotary_emb
    self.q_proj = q_proj
    self.kv_b_proj = kv_b_proj
    self.o_proj = o_proj
    self.kv_a_proj_with_mqa = kv_a_proj_with_mqa
    self.kv_a_layernorm = kv_a_layernorm
    self.q_a_layernorm = q_a_layernorm
    self.q_a_proj = q_a_proj

    self.tp_size = get_tensor_model_parallel_world_size()
    cur_device = torch.supa.current_device()
    device_props = torch_br.supa.get_device_properties(cur_device)
    self.spc_num = device_props.max_compute_units

    needs_p2p_init = (envs.VLLM_BR_USE_FUSED_ALLREDUCE
                      and self.tp_size == 8 and self.spc_num == 16)
    if needs_p2p_init:
        # Initialize the p2p info
        torch.supa.init_p2p_remote_id(cur_device)

    assert self.q_lora_rank is not None

    if alibi_slopes or sliding_window or logits_soft_cap:
        raise NotImplementedError(
            "SUPAFlashMLAImpl does not support one of the following: "
            "alibi_slopes, sliding_window, logits_soft_cap")

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "SUPAFlashMLAImpl")

    if is_quantized_kv_cache(self.kv_cache_dtype):
        raise NotImplementedError(
            "SUPAFlashMLA V1 with FP8 KV cache not yet supported")
# Pre-compute the fused weight matrices read by forward() (self.w_qkv_a,
# self.w_qk_b, self.w_vo) from the individual projection layers, then drop
# the original layer weights. Only the q_lora_rank path is implemented;
# otherwise NotImplementedError is raised.
# act_dtype: dtype of the identity matrix used to dequantize quantized layers.
def process_weights_after_loading(self, act_dtype: torch.dtype):
# Return the first weight-like attribute found on `layer`
# ("weight", "qweight" or "weight_packed"); AttributeError if none exists.
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
if hasattr(layer, attr):
return getattr(layer, attr)
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute:"
f" {WEIGHT_NAMES}.")
# Return the layer's weight. For a layer whose quant_method is not
# UnquantizedLinearMethod, the quantized linear is applied to an identity
# matrix and the result is transposed; otherwise layer.weight is returned.
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device)
dequant_weights = layer.quant_method.apply(layer,
eye,
bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
if self.q_lora_rank is not None:
# handle deepseek_v3 weight
# Fuse the q_a and kv_a down-projections into one matrix by concatenating
# their transposed weights along the last dimension.
w_q_a = get_and_maybe_dequant_weights(self.q_a_proj).T
w_kv_a = get_and_maybe_dequant_weights(self.kv_a_proj_with_mqa).T
w_qkv_a = torch.cat([w_q_a, w_kv_a], dim=-1)
# w_qkv_a must make two copies in br166
align_size = 32
die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
if die_spc_num > 16:
w_qkv_a = torch.cat([w_qkv_a, w_qkv_a], dim=-1)
# NOTE(review): _convert_to_numa_tensor is defined outside this block;
# presumably it re-lays the tensor out for the device - confirm.
self.w_qkv_a = _convert_to_numa_tensor(w_qkv_a, align_size,
"colmajor", w_qkv_a.dtype)
# Split kv_b into its per-head key (nope) and value parts.
w_kv_b = get_and_maybe_dequant_weights(self.kv_b_proj).T
w_k_b, w_v_b = w_kv_b.reshape(
self.kv_lora_rank, -1,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
w_k_b = w_k_b.permute(1, 2, 0).contiguous()
w_v_b = w_v_b.permute(1, 0, 2).contiguous()
# Fold the per-head value up-projection into the output projection:
# w_vo = bmm(w_v_b, w_o), flattened to 2-D.
w_o = get_and_maybe_dequant_weights(self.o_proj.to(w_v_b.device)).T
hidden_dim = w_o.shape[-1]
w_o = w_o.reshape(-1, self.v_head_dim, hidden_dim)
w_vo = torch.bmm(w_v_b, w_o).reshape(-1, hidden_dim)
self.w_vo = _convert_to_numa_tensor(w_vo,
align_size,
"colmajor",
w_qkv_a.dtype,
parallel_type="row_parallel")
# replace q_b_proj as q_proj
# Split q_b into nope/rope parts and fold the key up-projection into the
# nope part: w_qk_b_nope = bmm(w_q_b_nope, w_k_b).
w_q_b = get_and_maybe_dequant_weights(self.q_proj).T
w_q_b_nope, w_q_b_rope = w_q_b.reshape(
self.q_lora_rank, -1, self.qk_head_dim).split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
w_q_b_nope = w_q_b_nope.permute(1, 0, 2).contiguous()
w_q_b_rope = w_q_b_rope.reshape(self.q_lora_rank, -1)
w_qk_b_nope = torch.bmm(w_q_b_nope, w_k_b).permute(
1, 0, 2).contiguous().reshape(self.q_lora_rank, -1)
# w_qk_b_nope w_q_b_rope is independent head, separate like QKVParallelLinear
# With more than 16 SPCs, the nope and rope halves are interleaved as
# [nope0, rope0, nope1, rope1]; otherwise they are simply concatenated.
if die_spc_num > 16:
qk_b_nope0, qk_b_nope1 = torch.chunk(w_qk_b_nope, 2, dim=-1)
qk_b_rope0, qk_b_rope1 = torch.chunk(w_q_b_rope, 2, dim=-1)
w_qk_b = torch.cat(
[qk_b_nope0, qk_b_rope0, qk_b_nope1, qk_b_rope1], dim=-1)
else:
w_qk_b = torch.cat([w_qk_b_nope, w_q_b_rope], dim=-1)
self.w_qk_b = _convert_to_numa_tensor(w_qk_b, align_size,
"colmajor", w_qkv_a.dtype)
# The fused matrices replace the per-layer weights, so drop the originals.
self.q_a_proj.weight = None
self.kv_a_proj_with_mqa.weight = None
self.q_proj.weight = None
self.kv_b_proj.weight = None
self.o_proj.weight = None
# Both layernorm weights are converted to float32 if they are not already.
if self.kv_a_layernorm.weight.dtype != torch.float32:
self.kv_a_layernorm.weight.data = self.kv_a_layernorm.weight.to(
torch.float32)
if self.q_a_layernorm.weight.dtype != torch.float32:
self.q_a_layernorm.weight.data = self.q_a_layernorm.weight.to(
torch.float32)
else:
# Models without q_lora_rank are not handled by this backend.
raise NotImplementedError
torch.supa.empty_cache()
def _forward_decode(
    self,
    q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: FlashMLAMetadata,
    layer: AttentionLayer,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Unimplemented decode hook.

    This class's ``forward`` calls the decode kernel directly, so this
    method only raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
def forward(
    self,
    layer: AttentionLayer,
    hidden_states: torch.Tensor,  # query in unified attn
    positions: torch.Tensor,  # reuse k_c_normed as position
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: SupaFlashMLAMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the fused MLA attention and the output projection.

    Steps: fused Q/KV projection through the SUPA MLA prefix kernel, then
    either the flash-attention kernel (when the batch has prefill tokens)
    or the decode kernel, then the fused value/output projection with a
    tensor-parallel all-reduce.

    Args:
        layer: Attention layer object; not used by this implementation.
        hidden_states: Input activations. Indexed as ``shape[1]`` /
            ``shape[-2]`` for the token count and ``shape[-1]`` for the
            hidden width (presumably ``[1, num_tokens, hidden]`` - TODO
            confirm).
        positions: Token positions, passed to the prefix kernel together
            with the rotary sin/cos caches.
        k_pe: Not used by this implementation.
        kv_cache: Paged KV cache passed to the prefix and attention
            kernels.
        attn_metadata: Metadata for attention; ``None`` during
            profiling / warm-up.
        output: Must be ``None``; the result tensor is allocated by the
            kernels.

    Returns:
        The projected attention output, or ``hidden_states`` unchanged
        when ``attn_metadata`` is ``None``.
    """
    assert output is None, "Output tensor should not provided."
    if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
            self, "grandparent_pid"):
        self.grandparent_pid = get_grandparent_pid()

    # profile and warm up mla attention kernel
    if attn_metadata is None:
        return hidden_states

    # handle deepseek_v3 mla: both kernel versions take the same argument
    # list, only the token-count threshold selects between them.
    mla_prefix_infer = (torch_br.supa_mla_prefix_infer_v2
                        if hidden_states.shape[1] <= 512 else
                        torch_br.supa_mla_prefix_infer_v3)
    query, _ = mla_prefix_infer(
        hidden_states, self.w_qkv_a, self.w_qk_b,
        self.q_a_layernorm.weight, self.kv_a_layernorm.weight,
        self.rotary_emb.sin_cache, self.rotary_emb.cos_cache, positions,
        kv_cache, attn_metadata.slot_mapping, self.num_heads,
        self.qk_head_dim, self.qk_nope_head_dim, self.qk_rope_head_dim,
        self.kv_lora_rank, self.v_head_dim, self.q_lora_rank,
        self.kv_a_layernorm.variance_epsilon)

    # NOTE: no output buffer is pre-allocated here. Both attention kernels
    # below return their own result tensor, so a buffer created up front
    # would be discarded without ever being read.
    if attn_metadata.num_prefill_tokens > 0:
        assert len(query.shape) == 3
        output = torch_br.br_flash_attn_with_kvcache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            self.head_size,
            alibi_slopes=None,
            softmax_scale=self.scale,
            v_head_size=self.kv_lora_rank,
            num_reqs=attn_metadata.num_actual_reqs,
        )
    else:
        assert len(query.shape) == 3 and attn_metadata.num_prefills == 0
        output = torch_br.supa_attention_decoder_infer_v2(  # type: ignore
            query,  # type: ignore
            kv_cache,
            attn_metadata.block_table,
            attn_metadata.seq_lens,
            attn_metadata.max_decode_seq_len,
            self.head_size,
            attn_metadata.num_prefills,
            alibi_slopes=None,
            softmax_scale=self.scale,
            v_head_size=self.kv_lora_rank,
        )

    # The fused linear+allreduce kernel is used only when it is enabled,
    # the token count is within VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and the
    # (spc_num, tp_size) pair is one of the supported combinations below.
    seq_len = hidden_states.shape[-2]
    tp_size = get_tensor_model_parallel_world_size()
    support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
    fused_comm = (envs.VLLM_BR_USE_FUSED_ALLREDUCE
                  and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
                  and
                  (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
    if fused_comm:
        tp_rank = get_tp_group().rank_in_group
        global_rank = get_tp_group().rank
        rank_i = global_rank % tp_size
        assert rank_i == tp_rank
        o_proj_out = torch_br.supa_fused_linear_allreduce_opt(
            output, self.w_vo, hidden_states.shape[-1], tp_rank, tp_size,
            global_rank, 0)
    else:
        # do o_proj, then reduce across tensor-parallel ranks if needed
        output_parallel = torch_br.br_fused_mlp_infer(
            output, [self.w_vo], output_w=hidden_states.shape[-1])
        if self.tp_size > 1:
            o_proj_out = tensor_model_parallel_all_reduce(output_parallel)
        else:
            o_proj_out = output_parallel
    return o_proj_out

View File

@@ -0,0 +1,450 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import itertools
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple
import numpy as np
import torch
import torch_br
from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata
from vllm.attention.ops.flashmla import get_mla_metadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend, FlashMLASparseImpl, FlashMLASparseMetadata,
FlashMLASparseMetadataBuilder)
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
logger = init_logger(__name__)
# Sentinel used as the default of SupaFlashMLASparseMetadata's tensor fields;
# its __post_init__ treats a field still equal to this object as "not passed".
_NO_DEFAULT = object()
@dataclass
class SupaFlashMLASparseMetadata(FlashMLASparseMetadata):
    """Sparse FlashMLA metadata extended with the Biren attention fields.

    Every field below is required. They carry placeholder defaults
    (``_NO_DEFAULT`` for tensors, ``-1`` for integers), presumably because
    the parent dataclass already declares defaulted fields. A field still
    holding its placeholder after construction is reported as missing.
    """
    # BIREN Attention Params
    seq_start_loc: torch.Tensor = _NO_DEFAULT
    context_lens: torch.Tensor = _NO_DEFAULT
    max_decode_seq_len: int = -1
    num_prefills: int = -1
    num_decodes: int = -1
    num_prefill_tokens: int = -1
    num_decode_tokens: int = -1

    def __post_init__(self):
        """Reject construction when any Biren field was left unset.

        Raises:
            TypeError: Naming every field that still holds its placeholder.
        """
        missing = [
            name for name in ("seq_start_loc", "context_lens")
            if getattr(self, name) is _NO_DEFAULT
        ]
        missing.extend(
            name for name in ("max_decode_seq_len", "num_prefills",
                              "num_decodes", "num_prefill_tokens",
                              "num_decode_tokens")
            if getattr(self, name) == -1)
        if missing:
            raise TypeError("__init__ missing required argument(s): " +
                            ", ".join(missing))
class SupaFlashMLASparseMetadataBuilder(FlashMLASparseMetadataBuilder):
reorder_batch_threshold: int = 1
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    """Initialise the sparse builder and its decode reorder threshold.

    The threshold starts at the class default and is raised by one when
    at least one speculative token is configured.
    """
    super().__init__(
        kv_cache_spec=kv_cache_spec,
        layer_names=layer_names,
        vllm_config=vllm_config,
        device=device,
    )
    self.vllm_config = vllm_config

    spec_config = self.vllm_config.speculative_config
    if spec_config:
        self.num_speculative_tokens = spec_config.num_speculative_tokens
    else:
        self.num_speculative_tokens = 0

    # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2, so at
    # most one speculative token is added to the threshold.
    extra_decode_tokens = min(self.num_speculative_tokens, 1)
    self.reorder_batch_threshold += extra_decode_tokens
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """Reorder the batch so prefill requests come before decode requests.

    A request counts as "decode" when it has exactly one scheduled token
    once its speculative tokens are subtracted; everything else counts as
    "prefill". Prefills sitting behind the prefill region are swapped with
    the earliest decodes, which keeps the number of swaps minimal.

    The resulting counts are stored on the builder (``_num_decodes``,
    ``_num_prefills``, ``_num_decode_tokens``, ``_num_prefill_tokens``).

    Returns:
        ``True`` if at least one pair of requests was swapped.
    """
    scheduled_tokens = scheduler_output.num_scheduled_tokens
    spec_tokens = scheduler_output.scheduled_spec_decode_tokens

    decode_slots = []
    prefill_slots = []
    decode_token_total = 0
    prefill_token_total = 0
    for slot, req_id in enumerate(input_batch.req_ids):
        token_count = scheduled_tokens[req_id]
        spec_count = len(spec_tokens.get(req_id, []))
        if token_count - spec_count == 1:
            decode_slots.append(slot)
            decode_token_total += token_count
        else:
            prefill_slots.append(slot)
            prefill_token_total += token_count

    decode_count = len(decode_slots)
    prefill_count = len(prefill_slots)

    # Both lists are ascending. Pair the earliest decodes with the latest
    # prefills; as soon as a prefill already lies inside the leading
    # prefill region, all remaining ones do too and we can stop.
    swapped = False
    for decode_slot, prefill_slot in zip(decode_slots,
                                         reversed(prefill_slots)):
        if prefill_slot < prefill_count:
            break
        input_batch.swap_states(decode_slot, prefill_slot)
        swapped = True

    # Saved on the builder for later use.
    self._num_decodes = decode_count
    self._num_prefills = prefill_count
    self._num_decode_tokens = decode_token_total
    self._num_prefill_tokens = prefill_token_total
    return swapped
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> SupaFlashMLASparseMetadata:
    """Build the sparse-MLA metadata for one scheduled batch.

    Fills the persistent token-to-request buffer, prepares the FP8 kernel
    metadata when the FP8 KV cache is in use, derives the Biren attention
    fields the caller did not supply, and splits the batch into decode
    and prefill parts. ``common_prefix_len`` and ``fast_build`` are not
    used here.

    Returns:
        The populated ``SupaFlashMLASparseMetadata``.
    """
    meta = common_attn_metadata
    num_tokens = meta.num_actual_tokens

    # Map every token to the index of the request it belongs to.
    query_starts = np.asarray(meta.query_start_loc_cpu, dtype=np.int32)
    tokens_per_req = np.diff(query_starts)
    req_indices = np.arange(tokens_per_req.shape[0], dtype=np.int32)
    token_to_req = np.repeat(req_indices, tokens_per_req)
    # Zero-fill for cudagraphs
    self.req_id_per_token_buffer.fill_(0)
    self.req_id_per_token_buffer[:token_to_req.shape[0]].copy_(
        torch.from_numpy(token_to_req), non_blocking=True)
    req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

    fp8_extra_metadata = None
    if self.use_fp8_kv_cache:
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor,
            num_q_tokens_per_head_k=num_tokens * self.num_heads,
            topk=self.topk_tokens,
            num_heads_q=self.num_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )
        # Copy to persistent buffer for full-CG support
        sm_parts = tile_scheduler_metadata.size(0)
        scheduler_view = self.tile_scheduler_metadata_buffer[:sm_parts]
        scheduler_view.copy_(tile_scheduler_metadata)
        self.num_splits_buffer.copy_(num_splits)
        # cache_lens and block_table are basically unused in the sparse
        # case, but the decode kernel treats -1 and indices >= cache_lens
        # as invalid. cache_lens is therefore made large enough never to
        # invalidate an index by accident; -1 alone marks invalid ones.
        fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=scheduler_view,
            num_splits=self.num_splits_buffer,
            cache_lens=self.max_model_len_tensor,
            dummy_block_table=self.dummy_block_table)

    # Add biren attention params
    query_start_loc = meta.query_start_loc
    seq_lens = meta.seq_lens
    num_reqs = meta.num_reqs

    seq_start_loc = meta.seq_start_loc
    if seq_start_loc is None:
        # Exclusive prefix sum of the sequence lengths; batches of more
        # than 8 requests are moved to the host first.
        scan_source = seq_lens.cpu() if len(seq_lens) > 8 else seq_lens
        seq_start_loc = torch.tensor(
            [0] + list(itertools.accumulate(scan_source)),
            device=query_start_loc.device,
            dtype=torch.int32)

    context_lens = meta.context_lens
    if context_lens is None:
        query_lens = query_start_loc[1:] - query_start_loc[:-1]
        context_lens = seq_lens - query_lens

    max_decode_seq_len = meta.max_decode_seq_len
    if max_decode_seq_len is None:
        max_decode_seq_len = int(seq_lens.max().item())

    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
        split_decodes_and_prefills(
            meta, decode_threshold=self.reorder_batch_threshold)
    assert num_decodes + num_prefills == num_reqs
    assert num_decode_tokens + num_prefill_tokens == num_tokens

    return SupaFlashMLASparseMetadata(
        num_reqs=meta.num_reqs,
        max_query_len=meta.max_query_len,
        max_seq_len=meta.max_seq_len,
        num_actual_tokens=meta.num_actual_tokens,
        query_start_loc=meta.query_start_loc,
        slot_mapping=meta.slot_mapping,
        block_table=meta.block_table_tensor,
        req_id_per_token=req_id_per_token,
        block_size=self.kv_cache_spec.block_size,
        topk_tokens=self.topk_tokens,
        fp8_extra_metadata=fp8_extra_metadata,
        seq_start_loc=seq_start_loc,
        context_lens=context_lens,
        max_decode_seq_len=max_decode_seq_len,
        num_prefills=num_prefills,
        num_decodes=num_decodes,
        num_prefill_tokens=num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
    )
class SupaFlashMLASparseBackend(FlashMLASparseBackend):
    """Sparse FlashMLA backend that plugs in the SUPA classes."""

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier."""
        return "SUPA_FLASHMLA_SPARSE_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        """Return the metadata class produced by this backend's builder."""
        return SupaFlashMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["SupaFlashMLASparseMetadataBuilder"]:
        """Return the metadata builder class."""
        return SupaFlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["SupaFlashMLASparseImpl"]:
        """Return the attention implementation class."""
        return SupaFlashMLASparseImpl

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape used on SUPA.

        Blocks are grouped ``th_gran`` at a time (see
        ``get_kv_cache_usharp_alignment``), and the head and head-size
        dimensions are flattened into one. The result is
        ``(2, n_block, th_gran * block_size, num_kv_heads * head_size)``
        with ``n_block = max(1, ceil(num_blocks / th_gran))``.
        """
        th_gran = SupaFlashMLASparseBackend.get_kv_cache_usharp_alignment(
            block_size)
        # NOTE(review): th_gran is 0 when block_size > 2048, which makes
        # the division below fail - confirm such block sizes cannot occur.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # Lazy %-style arguments: the message is only formatted when debug
        # logging is enabled.
        logger.debug(
            "Origin kv cache shape is [2, %s, %s, %s, %s], "
            "For SUPA Speed up, use [2, %s, %s, %s]", num_blocks,
            block_size, num_kv_heads, head_size, n_block,
            th_gran * block_size, num_kv_heads * head_size)
        return (2, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks fit in the 2048-row limit."""
        max_h_limit = 2048
        return max_h_limit // block_size
class SupaFlashMLASparseImpl(FlashMLASparseImpl):
def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA-specific arguments
        topk_indice_buffer: Optional[torch.Tensor] = None,
        indexer: Optional["Indexer"] = None,
        **mla_args) -> None:
    """Forward every argument unchanged to ``FlashMLASparseImpl``."""
    base_args = (num_heads, head_size, scale, num_kv_heads, alibi_slopes,
                 sliding_window, kv_cache_dtype, logits_soft_cap,
                 attn_type, kv_sharing_target_layer_name,
                 topk_indice_buffer, indexer)
    super().__init__(*base_args, **mla_args)
def _forward_bf16_kv(
        self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: SupaFlashMLASparseMetadata) -> torch.Tensor:
    """Run sparse attention over the BF16 KV cache with a top-k mask.

    Args:
        q: Query tensor, unpacked as ``(seq_len_q, num_heads, head_dim)``.
        kv_c_and_k_pe_cache: KV cache; only its first slice along dim 0 is
            passed to the kernel.
        topk_indices: Per-token selected key indices, indexed as
            ``[token, k]``. Entries outside ``[0, seq_len_q)`` are ignored.
        attn_metadata: Metadata supplying the start locations, block
            table, context lengths, slot mapping and maximum sequence
            length.

    Returns:
        Attention output reshaped to
        ``(seq_len_q, num_heads, self.kv_lora_rank)``.
    """
    bsz = 1
    seq_len_q, num_heads, _ = q.shape

    # Build the mask on the host in one pass and upload it once. Indexing
    # the device tensors element by element from Python costs a
    # host/device round-trip per (token, k) pair.
    # Mask convention: 1 everywhere, 0 at each selected key position.
    selected = topk_indices[:seq_len_q].to(device="cpu", dtype=torch.int64)
    in_range = (selected >= 0) & (selected < seq_len_q)
    rows, cols = torch.nonzero(in_range, as_tuple=True)
    host_mask = torch.ones((bsz, seq_len_q, seq_len_q), dtype=torch.int32)
    host_mask[:, rows, selected[rows, cols]] = 0
    index_mask = host_mask.to(q.device)

    query = q.transpose(0, 1).contiguous()  # [num_heads, seq_len, head_dim]
    # NOTE(review): the kernel result is reshaped below from what is
    # presumably [1, seq_len, num_heads * kv_lora_rank] - confirm.
    output = torch_br.supa_flash_attn_cache_infer(
        query,
        kv_c_and_k_pe_cache[:1],  # first cache slice only
        attn_metadata.query_start_loc,
        attn_metadata.seq_start_loc,
        attn_metadata.block_table,
        attn_metadata.context_lens,
        attn_metadata.slot_mapping,
        attn_metadata.max_seq_len,
        self.head_size,
        softmax_scale=self.softmax_scale,
        v_head_size=self.kv_lora_rank,
        mask=index_mask)
    output = output.reshape(seq_len_q, num_heads,
                            self.kv_lora_rank).contiguous()
    return output
def forward(
    self,
    layer: AttentionLayer,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: SupaFlashMLASparseMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sparse MLA forward pass writing its result into ``output``.

    Absorbs ``W_UK_T`` into the no-rope part of the query, stores the
    latent and rope keys into the KV cache, runs the BF16 sparse attention
    path and applies the value up-projection into ``output``.

    Returns:
        ``output``; zero-filled when ``attn_metadata`` is ``None``.

    Raises:
        NotImplementedError: If an output (block) scale is supplied.
        RuntimeError: If the KV cache dtype is ``"fp8_ds_mla"``.
    """
    # NOTE(lucas): the sparse FlashMLA kernels want the MQA 576/512
    # approach for both prefill and decode.
    assert output is not None, "Output tensor must be provided."

    if not (output_scale is None and output_block_scale is None):
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            " for MLACommonImpl")

    if attn_metadata is None:
        # The zero fill is required when used with DP + EP to ensure all
        # ranks within a DP group compute the same expert outputs.
        return output.fill_(0)

    token_count = attn_metadata.num_actual_tokens

    # Inputs and outputs may be padded for CUDA graphs
    q = q[:token_count, ...]
    k_c_normed = k_c_normed[:token_count, ...]
    k_pe = k_pe[:token_count, ...]

    q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                           dim=-1)
    # (B, N, P) -> (N, B, P); (N, B, P) x (N, P, L) -> (N, B, L);
    # then back to (B, N, L).
    ql_nope = torch.bmm(q_nope.transpose(0, 1),
                        self.W_UK_T).transpose(0, 1)
    q = torch.cat([ql_nope, q_pe], dim=-1)

    # TODO: handle index / kv_cache correctly (the per-request top-k
    # indices are not converted to global indices here).
    topk_indices = self.topk_indices_buffer[:token_count]

    # write the latent and rope to kv cache
    if kv_cache.numel() > 0:
        _, _, _, head_size = kv_cache.shape
        rope_keys = k_pe.squeeze(1).unsqueeze(0)
        key_supa = torch.cat([k_c_normed, rope_keys], dim=2)
        torch_br.supa_kvcache_store_infer_v2(kv_cache, key_supa, key_supa,
                                             attn_metadata.slot_mapping,
                                             head_size)

    if self.kv_cache_dtype == "fp8_ds_mla":
        raise RuntimeError("Not support fp8 on br.")
    attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
                                     attn_metadata)

    self._v_up_proj(attn_out, out=output[:token_count])
    return output

View File

@@ -0,0 +1,140 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Tuple
import torch
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend, DeepSeekV32IndexerDecodeMetadata,
DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadataBuilder,
DeepseekV32IndexerPrefillMetadata, split_prefill_chunks)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
split_decodes_and_prefills)
logger = init_logger(__name__)
class SupaDeepseekV32IndexerBackend(DeepseekV32IndexerBackend):
    """DeepSeek V3.2 indexer backend that plugs in the SUPA builder."""

    @staticmethod
    def get_builder_cls() -> type["SupaDeepseekV32IndexerMetadataBuilder"]:
        """Return the metadata builder class."""
        return SupaDeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the indexer KV-cache shape used on SUPA.

        Blocks are grouped ``th_gran`` at a time (see
        ``get_kv_cache_usharp_alignment``), and the head and head-size
        dimensions are flattened into one. The result is
        ``(1, n_block, th_gran * block_size, num_kv_heads * head_size)``
        with ``n_block = max(1, ceil(num_blocks / th_gran))``.
        """
        th_gran = SupaDeepseekV32IndexerBackend.get_kv_cache_usharp_alignment(
            block_size)
        # NOTE(review): th_gran is 0 when block_size > 2048, which makes
        # the division below fail - confirm such block sizes cannot occur.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # Lazy %-style arguments: the message is only formatted when debug
        # logging is enabled.
        logger.debug(
            "Origin kv cache shape is [1, %s, %s, %s, %s], "
            "For SUPA Speed up, use [1, %s, %s, %s]", num_blocks,
            block_size, num_kv_heads, head_size, n_block,
            th_gran * block_size, num_kv_heads * head_size)
        return (1, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks fit in the 2048-row limit."""
        max_h_limit = 2048
        return max_h_limit // block_size
class SupaDeepseekV32IndexerMetadataBuilder(DeepseekV32IndexerMetadataBuilder):
    """Metadata builder used by ``SupaDeepseekV32IndexerBackend``."""

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> DeepseekV32IndexerMetadata:
        """Assemble indexer metadata for one batch.

        The batch is split into decode and prefill requests with
        ``split_decodes_and_prefills``; prefill requests are grouped into
        chunks and decode requests get their own metadata object.

        Args:
            common_prefix_len: Not used by this implementation.
            common_attn_metadata: Shared per-batch attention metadata.
            fast_build: Not used by this implementation.

        Returns:
            The populated ``DeepseekV32IndexerMetadata``.
        """
        meta = common_attn_metadata
        total_reqs = meta.num_reqs
        total_tokens = meta.num_actual_tokens
        qsl_cpu = meta.query_start_loc_cpu

        split = split_decodes_and_prefills(
            meta, decode_threshold=self.reorder_batch_threshold)
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split
        assert num_decodes + num_prefills == total_reqs
        assert num_decode_tokens + num_prefill_tokens == total_tokens

        # ---- prefill part -------------------------------------------------
        prefill_metadata = None
        if num_prefills > 0:
            chunk_ranges = split_prefill_chunks(
                meta.seq_lens_cpu,
                self.max_prefill_buffer_size,
                num_decodes,
            )
            prefill_chunks = []
            for first_req, last_req in chunk_ranges:
                prefill_chunks.append(
                    self.build_one_prefill_chunk(first_req, last_req, qsl_cpu,
                                                 meta.seq_lens_cpu,
                                                 meta.block_table_tensor))
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=prefill_chunks)

        # ---- decode part --------------------------------------------------
        decode_metadata = None
        if num_decodes > 0:
            # Per-request query lengths, written into the persistent buffer.
            torch.diff(meta.query_start_loc[:num_decodes + 1],
                       out=self.decode_lens_buffer[:num_decodes])
            decode_lens = self.decode_lens_buffer[:num_decodes]

            # Padding is decided from the CPU copy so no device sync is
            # needed (a sync would break async scheduling).
            lens_on_cpu = torch.diff(
                meta.query_start_loc_cpu[:num_decodes + 1])
            longest = lens_on_cpu.max()
            shortest = lens_on_cpu.min()
            requires_padding = (longest > shortest).item()

            # No paged-MQA schedule metadata is computed here; the buffer
            # attribute is reset to None and None is handed to the decode
            # metadata.
            self.scheduler_metadata_buffer = None
            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=meta.block_table_tensor[:num_decodes, ...],
                seq_lens=meta.seq_lens[:num_decodes],
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                schedule_metadata=self.scheduler_metadata_buffer,
            )

        return DeepseekV32IndexerMetadata(
            seq_lens=meta.seq_lens,
            num_reqs=meta.num_reqs,
            max_query_len=meta.max_query_len,
            max_seq_len=meta.max_seq_len,
            num_actual_tokens=meta.num_actual_tokens,
            query_start_loc=meta.query_start_loc,
            slot_mapping=meta.slot_mapping,
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

View File

@@ -0,0 +1,47 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from dataclasses import dataclass
import torch
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm_br.config.compilation import SUPAGraphMode
@dataclass
class SUPACommonAttentionMetadata(CommonAttentionMetadata):
    """
    Attention metadata attributes that can be shared by layers in different KV
    cache groups and thus having different block table.
    """

    # (batch_size + 1,), the start location of each request in query Tensor
    query_start_loc: torch.Tensor

    # (batch_size,), the length of each request including both computed
    # tokens and newly scheduled tokens
    seq_lens: torch.Tensor

    # (1,), number of actual requests in the batch
    num_actual_reqs: torch.Tensor | None = None

    supagraph_runtime_mode: SUPAGraphMode | None = None

    # (batch_size,), the length of each request including computed tokens
    # only
    context_lens: torch.Tensor | None = None

    # The maximum length of the decoded sequence in the batch.
    max_decode_seq_len: int | None = None

    # (batch_size + 1,), the start location of each request in sequence
    # Tensor. This is used to compute the sequence length of each request.
    # If not provided, it will be computed from seq_lens.
    seq_start_loc: torch.Tensor | None = None

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import kv_cache_utils, sched # noqa: F401

Binary file not shown.

View File

@@ -0,0 +1,219 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from fastcore.basics import patch_to
import vllm.v1.core.kv_cache_utils
from vllm.config import VllmConfig
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import (
create_kv_cache_group_specs, get_max_concurrency_for_kv_cache_config,
get_uniform_page_size, may_override_num_blocks)
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, KVCacheTensor,
UniformTypeKVCacheSpecs)
from vllm_br.v1.attention.backends.attention_v1 import (
SUPAFlashAttentionBackend)
@patch_to(vllm.v1.core.kv_cache_utils)
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                      kv_cache_spec: dict[str, KVCacheSpec],
                                      available_memory: int) -> KVCacheConfig:
    """
    Generates the KV cache configuration for a model with one type of KV cache.
    Divide the available memory equally among all layers.

    The block count is rounded down to a multiple of the SUPA alignment
    granularity and capped at ``th_gran * 1024`` before any
    ``num_gpu_blocks_override`` is applied.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfig
    """
    page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
    assert len(page_sizes) == 1
    page_size = page_sizes.pop()

    # NOTE: SUPA has layouts
    # Both MLA/FlashAttention use the same gran
    th_gran = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
        vllm_config.cache_config.block_size)
    num_blocks = int(available_memory // page_size // len(kv_cache_spec))
    # NOTE: limit gpu blocks number due to the shape restriction of colmajor
    # layout
    num_blocks = min(th_gran * 1024, num_blocks // th_gran * th_gran)
    num_blocks = max(num_blocks, 0)

    if vllm_config.cache_config.num_gpu_blocks_override is not None:
        num_gpu_blocks_override = \
            vllm_config.cache_config.num_gpu_blocks_override
        logger.info(
            "Overriding num_gpu_blocks=%d with "
            "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
        num_blocks = num_gpu_blocks_override

    num_tokens = num_blocks * vllm_config.cache_config.block_size
    num_tokens_str = f"{num_tokens:,}"
    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
    max_concurrency = num_tokens / vllm_config.model_config.max_model_len
    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                max_model_len_str, max_concurrency)

    per_layer_size = page_size * num_blocks
    # All layers have the same KV cache spec, so we create one kv cache group
    # for all layers.
    grouped_layer_names = [list(kv_cache_spec.keys())]

    # Use the same KVCacheConfig / KVCacheTensor form as
    # get_kv_cache_config_from_groups in this module: a list of tensors, each
    # naming the layer(s) that share it. The previous ``tensors={...}`` dict
    # form does not match that API.
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[
            KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
            for layer_name in kv_cache_spec
        ],
        kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
                                                    grouped_layer_names),
    )
    return kv_cache_config


logger.info('===[Patch] patch _get_kv_cache_config_uniform_type')
# @patch_to(vllm.v1.core.kv_cache_utils)
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
                   available_memory: int, page_size: int) -> int:
    """
    Compute how many KV cache blocks can be allocated.

    The per-layer block count is rounded down to a multiple of the SUPA
    alignment granularity, capped at ``granularity * 1024``, floored at zero,
    and finally passed through ``may_override_num_blocks``.

    Args:
        vllm_config: The global VllmConfig
        num_layers: The number of layers
        available_memory: Memory available for KV cache in bytes.
        page_size: The page size of the KV cache.

    Returns:
        The number of blocks to allocate.
    """
    granularity = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
        vllm_config.cache_config.block_size)
    blocks_per_layer = int(available_memory // page_size // num_layers)
    aligned = blocks_per_layer // granularity * granularity
    capped = min(granularity * 1024, aligned)
    return may_override_num_blocks(vllm_config, max(capped, 0))
@patch_to(vllm.v1.core.kv_cache_utils)
def get_kv_cache_config_from_groups(vllm_config: VllmConfig,
                                    kv_cache_groups: list[KVCacheGroupSpec],
                                    kv_cache_specs: dict[str, KVCacheSpec],
                                    available_memory: int) -> KVCacheConfig:
    """
    Build the KVCacheConfig from the KV cache groups and per-layer specs,
    with block counts aligned to the SUPA layout granularity.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_groups: The KV cache groups
        kv_cache_specs: The KV cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes

    Returns:
        The generated KVCacheConfig
    """
    # Attention-free models have no KV cache; one block is still reported
    # because BlockPool always needs a null_block.
    if not kv_cache_groups:
        return KVCacheConfig(
            num_blocks=1,
            kv_cache_tensors=[],
            kv_cache_groups=kv_cache_groups,
        )

    n_groups = len(kv_cache_groups)
    head_group = kv_cache_groups[0]

    if n_groups == 1 and isinstance(head_group.kv_cache_spec,
                                    UniformTypeKVCacheSpecs):
        # Single group whose layers share a KV cache type but differ in
        # hidden size: every layer gets its own tensor, sized from its own
        # page size.
        granularity = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
            vllm_config.cache_config.block_size)
        raw_blocks = available_memory // head_group.kv_cache_spec.page_size_bytes
        aligned_blocks = min(granularity * 1024,
                             raw_blocks // granularity * granularity)
        num_blocks = may_override_num_blocks(vllm_config,
                                             max(aligned_blocks, 0))

        spec_by_layer = head_group.kv_cache_spec.kv_cache_specs
        kv_cache_tensors = []
        for layer_name in head_group.layer_names:
            layer_bytes = spec_by_layer[layer_name].page_size_bytes * num_blocks
            kv_cache_tensors.append(
                KVCacheTensor(size=layer_bytes, shared_by=[layer_name]))
    else:
        # General case: there are `group_size` memory pools, and pool i is
        # shared by the i-th layer of every group that has one. Layers of
        # different groups use different block tables, so they occupy
        # different parts of the shared tensor.
        # Example with groups (full.0, full.1), (sw.0, sw.2), (sw.1, padding)
        # and group_size = 2:
        #   full.0, sw.0, sw.1 share one tensor
        #   full.1, sw.2 share another
        group_size = max(len(group.layer_names) for group in kv_cache_groups)
        page_size = get_uniform_page_size(kv_cache_specs)
        assert group_size > 0, "group_size must be greater than 0"
        num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
                                    page_size)
        kv_cache_tensors = [
            KVCacheTensor(
                size=page_size * num_blocks,
                shared_by=[
                    group.layer_names[slot] for group in kv_cache_groups
                    if slot < len(group.layer_names)
                ],
            ) for slot in range(group_size)
        ]

    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=kv_cache_groups,
    )

    # Report the KV cache size and maximum concurrency.
    min_block_size = min(group.kv_cache_spec.block_size
                         for group in kv_cache_groups)
    num_tokens = num_blocks // n_groups * min_block_size
    dcp_size = vllm_config.parallel_config.decode_context_parallel_size
    if dcp_size > 1:
        num_tokens *= dcp_size
        logger.info(
            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
            dcp_size)
    logger.info("GPU KV cache size: %s tokens", f"{num_tokens:,}")

    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
    max_concurrency = get_max_concurrency_for_kv_cache_config(
        vllm_config, kv_cache_config)
    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                max_model_len_str, max_concurrency)
    return kv_cache_config


logger.info('===[Patch] patch get_kv_cache_config_from_groups')

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import scheduler # noqa: F401

View File

@@ -0,0 +1,558 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from __future__ import annotations
import itertools
import time
from typing import Optional
from fastcore.basics import patch_to
from vllm.distributed.kv_events import KVEventBatch
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
create_request_queue)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreEventType
from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)
@patch_to(Scheduler)
def schedule(self) -> SchedulerOutput:
"""Plan one engine step and return the resulting SchedulerOutput.

Installed on ``Scheduler`` via ``patch_to``. The method first walks
``self.running`` and then, only if nothing was preempted in this step,
``self.waiting``, allocating KV cache slots for each request while a shared
token budget (``self.max_num_scheduled_tokens``) lasts.

Two checks in this body compare the chunk size against
``self.num_spec_tokens + 1``: one in the RUNNING loop (applied when chunked
prefill is enabled and the request has no output tokens yet) and one in the
WAITING loop ("Too short waiting requests can not be scheduled").

Returns:
The SchedulerOutput for this step, after KV connector metadata and KV
cache events have been handled and ``self._update_after_schedule`` has
been called with it.
"""
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_compute_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# For logging.
scheduled_timestamp = time.monotonic()
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
num_new_tokens = (request.num_tokens_with_spec +
request.num_output_placeholders -
request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - 1 - request.num_computed_tokens)
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_compute_budget) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_compute_budget)
# Chunk-size adjustment for requests that have produced no output tokens
# yet: a chunk larger than the threshold is shrunk when the prompt tail
# it would leave is non-empty but at most the threshold; a chunk at or
# below the threshold is dropped to 0 while prompt tokens remain.
# NOTE(review): the post-shrink check below uses `<` while the
# surrounding checks use `>` / `<=`, so a chunk of exactly the threshold
# passes it - confirm this is intended.
if self.scheduler_config.chunked_prefill_enabled and request.num_output_tokens == 0:
# shortest chunked prefill length is num_spec_tokens + 1
prefill_schedul_threshold = self.num_spec_tokens + 1
# Calculate remaining prompt tokens when request is in prefill phase
remaining_prompt_tokens = request.num_tokens - request.num_computed_tokens - num_new_tokens
if num_new_tokens > prefill_schedul_threshold:
# Boundary condition: when remaining tokens equal or less than threshold,
# reduce current round's token count to prevent phase misclassification
# in reorder batch later in next round
if 0 < remaining_prompt_tokens <= prefill_schedul_threshold:
num_new_tokens -= (prefill_schedul_threshold -
remaining_prompt_tokens + 1)
num_new_tokens = 0 if num_new_tokens < prefill_schedul_threshold else num_new_tokens
elif remaining_prompt_tokens > 0:
# cannot schedule less than threshold tokens in chunked prefill
num_new_tokens = 0
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when
# (1) PP>1 and we have already scheduled all prompt tokens
# but they are not finished yet.
# (2) Async scheduling and the request has reached to either
# its max_total_tokens or max_model_len.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
req_index += 1
continue
# Try to allocate slots; on failure preempt another running request and
# retry until allocation succeeds or the current request itself is the
# one preempted.
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.log_stats:
preempted_req.record_event(EngineCoreEventType.PREEMPTED,
scheduled_timestamp)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt.
can_schedule = False
break
else:
# The request can be scheduled.
can_schedule = True
break
if not can_schedule:
break
assert new_blocks is not None
# Schedule the request.
scheduled_running_reqs.append(request)
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
request = self.waiting.peek_request()
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
else:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Check that adding the request still respects the max_loras
# constraint.
if (self.lora_config and request.lora_request and
(len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id not in scheduled_loras)):
# Scheduling would exceed max_loras, skip.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_external_computed_tokens = 0
load_kv_async = False
# Get already-cached tokens.
if request.num_computed_tokens == 0:
# Get locally-cached tokens.
new_computed_blocks, num_new_local_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
if num_external_computed_tokens is None:
# The request cannot be scheduled because
# the KVConnector couldn't determine
# the number of matched tokens.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
else:
new_computed_blocks = (
self.kv_cache_manager.create_empty_block_list())
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
new_encoder_compute_budget = encoder_compute_budget
# KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async:
assert num_external_computed_tokens > 0
num_new_tokens = 0
# Number of tokens to be scheduled.
else:
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_compute_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_compute_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
# NOTE(review): the nesting of the next check is not recoverable from
# this view. If it is reached when load_kv_async is set
# (num_new_tokens == 0) the request is skipped before allocate_slots,
# and a request whose schedulable tokens never exceed
# num_spec_tokens + 1 would be re-queued on every step - confirm both
# cases are intended.
if num_new_tokens <= self.num_spec_tokens + 1:
# Too short waiting requests can not be scheduled.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Handles an edge case when P/D Disaggregation
# is used with Spec Decoding where an
# extra block gets allocated which
# creates a mismatch between the number
# of local and remote blocks.
effective_lookahead_tokens = (0 if request.num_computed_tokens == 0
else self.num_lookahead_tokens)
# Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
num_encoder_tokens =\
self.scheduler_config.max_num_encoder_input_tokens
else:
num_encoder_tokens = 0
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens,
delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens,
)
if new_blocks is None:
# The request cannot be scheduled.
break
# KVTransfer: the connector uses this info to determine
# if a load is needed. Note that
# This information is used to determine if a load is
# needed for this request.
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
req_index += 1
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,
scheduled_timestamp)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_blocks[request.request_id] = (
self.kv_cache_manager.get_blocks(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
# Since some requests in the RUNNING queue may not be scheduled in
# this step, the total number of scheduled requests can be smaller than
# len(self.running).
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
len(scheduled_running_reqs) <= len(self.running))
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running)))
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs +
scheduled_resumed_reqs)
structured_output_request_ids, grammar_bitmask = (self.get_grammar_bitmask(
scheduled_requests, scheduled_spec_decode_tokens))
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(
),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
# 1. Plan the KV cache store
# 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
# collect KV cache events from KV cache manager
events = self.kv_cache_manager.take_events()
# collect KV cache events from connector
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
# publish collected KV cache events
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
self._update_after_schedule(scheduler_output)
return scheduler_output
@patch_to(Scheduler)
def _make_cached_request_data(
    self,
    running_reqs: list[Request],
    resumed_reqs: list[Request],
    num_scheduled_tokens: dict[str, int],
    spec_decode_tokens: dict[str, list[int]],
    req_to_new_blocks: dict[str, KVCacheBlocks],
) -> CachedRequestData:
    """Assemble the ``CachedRequestData`` for already-known requests.

    Running requests are listed first, followed by requests resumed from
    preemption; every per-request list in the result uses that order.

    Args:
        running_reqs: Requests that keep running in this step.
        resumed_reqs: Requests that are resumed in this step.
        num_scheduled_tokens: Number of tokens scheduled per request id.
        spec_decode_tokens: Scheduled speculative tokens per request id
            (a missing id counts as no speculative tokens).
        req_to_new_blocks: Newly allocated KV cache blocks per request id.

    Returns:
        The populated ``CachedRequestData``.
    """
    # NOTE: this variant keys the token hand-back on the absence of a
    # KV connector (the earlier `self.use_pp` check is disabled).
    has_connector = self.connector is not None

    ids: list[str] = []
    token_batches: list[list[int]] = []
    block_batches: list[Optional[tuple[list[int], ...]]] = []
    computed_counts: list[int] = []

    for request in itertools.chain(running_reqs, resumed_reqs):
        rid = request.request_id
        ids.append(rid)

        # Scheduled tokens minus the speculative ones.
        fresh_count = (num_scheduled_tokens[rid] -
                       len(spec_decode_tokens.get(rid, ())))

        if has_connector:
            # With a KVConnector only a placeholder is sent, which keeps the
            # per-request lists aligned and avoids index-out-of-bounds errors.
            # TODO: Remove this once the KVConnector is updated to handle
            # token IDs properly.
            token_batches.append([])
        else:
            # Send back the slice of token ids that starts at the tokens
            # already computed and spans the non-speculative scheduled ones.
            window = request.all_token_ids[request.num_computed_tokens:request.
                                           num_computed_tokens + fresh_count]
            token_batches.append(window)

        block_batches.append(
            req_to_new_blocks[rid].get_block_ids(allow_none=True))
        computed_counts.append(request.num_computed_tokens)

    # One flag per request, in the same running-then-resumed order.
    resumed_flags = [False] * len(running_reqs) + [True] * len(resumed_reqs)

    return CachedRequestData(
        req_ids=ids,
        resumed_from_preemption=resumed_flags,
        new_token_ids=token_batches,
        new_block_ids=block_batches,
        num_computed_tokens=computed_counts,
    )

View File

@@ -0,0 +1,19 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import async_llm # noqa
from . import core # noqa: F401
from . import llm_engine # noqa

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,179 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import asyncio
import os
import socket
from typing import Optional
import torch
from fastcore.basics import patch_to
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm_br import envs as envs_br
from vllm_br.utils import (create_cpu_all_reduce_shared_mem,
get_cpu_all_reduce_shared_mem)
logger = init_logger(__name__)
@patch_to(AsyncLLM)
def __init__(
    self,
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    use_cached_outputs: bool = False,
    log_requests: bool = True,
    start_engine_loop: bool = True,
    stat_loggers: Optional[list[StatLoggerFactory]] = None,
    client_addresses: Optional[dict[str, str]] = None,
    client_count: int = 1,
    client_index: int = 0,
) -> None:
    """Construct an AsyncLLM, with optional CPU all-reduce shared memory.

    When ``envs_br.VLLM_BR_USE_CPU_ALL_REDUCE`` is non-zero, the CPU
    all-reduce shared memory is created before the rest of the engine
    front-end (tokenizer, processors, engine core client, stat loggers,
    profiler) is set up.

    Args:
        vllm_config: Global configuration.
        executor_class: An Executor implementation, e.g. MultiprocExecutor.
        log_stats: Whether to log stats.
        usage_context: Usage context of the LLM (not read by this body).
        mm_registry: Multi-modal registry handed to the Processor.
        use_cached_outputs: Whether to use cached outputs (not read by
            this body).
        log_requests: Whether to log requests.
        start_engine_loop: Whether to start the engine loop (not read by
            this body).
        stat_loggers: Custom stat logger factories. Passing a list enables
            stat logging even when ``log_stats`` is False.
        client_addresses: Forwarded to the engine core client factory.
        client_count: Forwarded to the engine core client and the stat
            logger manager.
        client_index: Forwarded to the engine core client factory.

    Raises:
        ValueError: If ``envs.VLLM_USE_V1`` is falsy.
    """
    if not envs.VLLM_USE_V1:
        raise ValueError(
            "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
            "This should not happen. As a workaround, try using "
            "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
            "VLLM_USE_V1=0 or 1 and report this issue on Github.")

    # BR-specific: set up the shared memory used for CPU all-reduce.
    if envs_br.VLLM_BR_USE_CPU_ALL_REDUCE != 0:
        create_cpu_all_reduce_shared_mem()

    # Ensure we can serialize custom transformer configs.
    maybe_register_config_serialize_by_value()

    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.observability_config = vllm_config.observability_config
    self.log_requests = log_requests

    # Custom stat loggers force stat logging on.
    has_custom_loggers = stat_loggers is not None
    self.log_stats = log_stats or has_custom_loggers
    if has_custom_loggers and not log_stats:
        logger.info(
            "AsyncLLM created with log_stats=False and non-empty custom "
            "logger list; enabling logging without default stat loggers")

    # Tokenizer (+ ensure liveness if running in another process),
    # unless tokenizer initialization is explicitly skipped.
    self.tokenizer = (None if self.model_config.skip_tokenizer_init else
                      init_tokenizer_from_configs(
                          model_config=vllm_config.model_config))

    # Processor (converts Inputs --> EngineCoreRequests).
    self.processor = Processor(
        vllm_config=vllm_config,
        tokenizer=self.tokenizer,
        mm_registry=mm_registry,
    )

    # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
    self.output_processor = OutputProcessor(self.tokenizer,
                                            log_stats=self.log_stats)
    otlp_endpoint = self.observability_config.otlp_traces_endpoint
    if otlp_endpoint is not None:
        self.output_processor.tracer = init_tracer("vllm.llm_engine",
                                                   otlp_endpoint)

    # EngineCore (starts the engine in background process).
    self.engine_core = EngineCoreClient.make_async_mp_client(
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=self.log_stats,
        client_addresses=client_addresses,
        client_count=client_count,
        client_index=client_index,
    )

    # Loggers.
    self.logger_manager: Optional[StatLoggerManager] = None  # type: ignore
    if self.log_stats:
        self.logger_manager = StatLoggerManager(
            vllm_config=vllm_config,
            engine_idxs=self.engine_core.engine_ranks_managed,
            custom_stat_loggers=stat_loggers,
            enable_default_loggers=log_stats,
            client_count=client_count,
        )
        self.logger_manager.log_engine_initialized()

    self.output_handler: Optional[asyncio.Task] = None  # type: ignore
    try:
        # Start output handler eagerly if we are in the asyncio eventloop;
        # get_running_loop() raises RuntimeError when there is none.
        asyncio.get_running_loop()
        self._run_output_handler()
    except RuntimeError:
        pass

    profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
    if not profiler_dir:
        self.profiler = None
    else:
        logger.info(
            "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
            profiler_dir)
        worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
        trace_handler = torch.profiler.tensorboard_trace_handler(
            profiler_dir,
            worker_name=worker_name,
            use_gzip=True)
        self.profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
            ],
            with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
            on_trace_ready=trace_handler)
@patch_to(AsyncLLM)
def __del__(self):
    """Release the CPU all-reduce shared memory, then shut the engine down.

    The shutdown runs in a ``finally`` clause so that a failure while
    cleaning up the shared memory cannot prevent ``self.shutdown()`` from
    being called.
    """
    try:
        # Fetch the handle once and reuse it for the check and the cleanup.
        shared_mem = get_cpu_all_reduce_shared_mem()
        if shared_mem is not None:
            shared_mem._cleanup()
    finally:
        self.shutdown()

157
vllm_br/v1/engine/core.py Normal file
View File

@@ -0,0 +1,157 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import os
import time
from typing import Optional
from fastcore.basics import patch_to
from vllm.config import ParallelConfig, VllmConfig
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
get_kv_cache_configs)
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.kv_cache_interface import KVCacheConfig
# Patched replacement for EngineCore._initialize_kv_caches.
# Gathers the KV cache specs from the model executor, determines the memory
# available for the KV cache, builds the KV cache configs, initializes the
# executor from them and logs the elapsed time.
# Returns (num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config);
# num_cpu_blocks is always 0 here.
@patch_to(EngineCore)
def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time()
# Get all kv cache needed by the model
kv_cache_specs = self.model_executor.get_kv_cache_specs()
# True as soon as at least one returned spec is truthy.
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
if has_kv_cache:
# When VLLM_ELASTIC_EP_SCALE_UP_LAUNCH is "1", the KV cache memory size is
# obtained through the DP group instead of being profiled locally.
if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
dp_group = getattr(self, "dp_group", None)
assert dp_group is not None
self.available_gpu_memory_for_kv_cache = \
ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
available_gpu_memory = [self.available_gpu_memory_for_kv_cache
] * len(kv_cache_specs)
else:
# Profiles the peak memory usage of the model to determine how
# much memory can be allocated for kv cache.
available_gpu_memory = (
self.model_executor.determine_available_memory())
self.available_gpu_memory_for_kv_cache = \
available_gpu_memory[0]
else:
# Attention free models don't need memory for kv cache
available_gpu_memory = [0] * len(kv_cache_specs)
# NOTE(review): the next statement overwrites `available_gpu_memory` that was
# just computed above. The indentation of this listing is lost, so it is
# unclear whether it belongs to the attention-free `else` branch or to the
# function body (where it would also override the two branches above and
# trigger a second profiling call) -- confirm the intended placement.
available_gpu_memory = self.model_executor.determine_available_memory()
# One memory figure is expected per KV cache spec entry.
assert len(kv_cache_specs) == len(available_gpu_memory)
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
available_gpu_memory)
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
kv_cache_configs)
num_gpu_blocks = scheduler_kv_cache_config.num_blocks
num_cpu_blocks = 0
# Initialize kv cache and warmup the execution
self.model_executor.initialize_from_config(kv_cache_configs)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
@patch_to(EngineCore)
def step_with_batch_queue(
        self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
    """Schedule and execute batches through the batch queue.

    Flow:

    1. If the scheduler has requests, schedule a batch, submit it to the
       executor without blocking and push it onto the queue. When that
       batch actually contains tokens, the queue still has room and the
       oldest queued batch is not finished yet, return ``(None, True)``
       right away: filling the queue takes priority over collecting
       outputs.
    2. Otherwise block until the oldest batch in the queue is finished.
    3. If that batch had scheduled tokens, update the scheduler from its
       output (and hand over draft token ids when speculative decoding is
       enabled); otherwise report nothing.

    Returns:
        ``(engine_core_outputs, model_executed)``. The first element is
        None when there is nothing to output in this step.
    """
    queue = self.batch_queue
    assert queue is not None

    # Try to schedule a new batch if the batch queue is not full, but the
    # scheduler may return an empty batch if all requests are scheduled.
    # Note that this is not blocking.
    assert len(queue) < self.batch_queue_size

    model_executed = False
    if self.scheduler.has_requests():
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output,
                                                   non_block=True)
        queue.appendleft(
            (future, scheduler_output))  # type: ignore[arg-type]

        model_executed = scheduler_output.total_num_scheduled_tokens > 0
        queue_has_room = len(queue) < self.batch_queue_size
        if model_executed and queue_has_room and not queue[-1][0].done():
            # Don't block on next worker response unless the queue is full
            # or there are no more requests to schedule.
            return None, True
    elif not queue:
        # Queue is empty. We should not reach here since this method should
        # only be called when the scheduler contains requests or the queue
        # is non-empty.
        return None, False

    # Block until the next result is available.
    future, scheduler_output = queue.pop()
    model_output = self.execute_model_with_error_logging(
        lambda _: future.result(), scheduler_output)

    # A batch without scheduled tokens produces no engine outputs.
    if scheduler_output.total_num_scheduled_tokens == 0:
        return None, False

    engine_core_outputs = self.scheduler.update_from_output(
        scheduler_output, model_output)

    # With speculative decoding, the draft token ids travel on the model
    # output itself (instead of
    # `self.model_executor.take_draft_token_ids()`).
    if self.use_spec_decode and model_output.draft_token_ids is not None:
        model_output.draft_token_ids.req_ids = model_output.req_ids
        self.scheduler.update_draft_token_ids(model_output.draft_token_ids)

    return engine_core_outputs, model_executed
@patch_to(EngineCoreProc)
def _process_engine_step(self) -> bool:
    """Run one engine step and enqueue its outputs.

    Called only when there are unfinished local requests.

    Returns:
        The ``model_executed`` flag reported by ``self.step_fn()``.
    """
    # Step the engine core.
    outputs, model_executed = self.step_fn()

    # Put each (key, EngineCoreOutputs) item into the output queue; nothing
    # is enqueued when the step produced no outputs.
    if outputs:
        for item in outputs.items():
            self.output_queue.put_nowait(item)

    # The post-step hook is currently disabled:
    # if outputs is not None:
    #     self.post_step(model_executed)
    return model_executed

View File

@@ -0,0 +1,143 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
from fastcore.basics import patch_to
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.distributed.parallel_state import get_dp_group
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm_br import envs as envs_br
from vllm_br.utils import (create_cpu_all_reduce_shared_mem,
get_cpu_all_reduce_shared_mem)
@patch_to(LLMEngine)
def __init__(
    self,
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[list[StatLoggerFactory]] = None,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    use_cached_outputs: bool = False,
    multiprocess_mode: bool = False,
) -> None:
    """Construct an LLMEngine, with optional CPU all-reduce shared memory.

    When ``envs_br.VLLM_BR_USE_CPU_ALL_REDUCE`` is non-zero, the CPU
    all-reduce shared memory is created before the engine components
    (DP group, tokenizer, processors, engine core client, stat logger
    manager) are set up.

    Args:
        vllm_config: Global configuration.
        executor_class: An Executor implementation.
        log_stats: Whether to log stats.
        usage_context: Usage context (not read by this body).
        stat_loggers: Must be None; custom stat loggers are rejected.
        mm_registry: Multi-modal registry handed to the Processor.
        use_cached_outputs: Whether to use cached outputs (not read by
            this body).
        multiprocess_mode: Whether the engine core runs in another process.

    Raises:
        ValueError: If ``envs.VLLM_USE_V1`` is falsy.
        NotImplementedError: If ``stat_loggers`` is not None.
    """
    if not envs.VLLM_USE_V1:
        raise ValueError("Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                         "This should not happen. As a workaround, try using "
                         "LLMEngine.from_vllm_config(...) or explicitly set "
                         "VLLM_USE_V1=0 or 1 and report this issue on Github.")

    if stat_loggers is not None:
        raise NotImplementedError(
            "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
            "Set VLLM_USE_V1=0 and file and issue on Github.")

    # BR-specific: set up the shared memory used for CPU all-reduce.
    if envs_br.VLLM_BR_USE_CPU_ALL_REDUCE != 0:
        create_cpu_all_reduce_shared_mem()

    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.cache_config = vllm_config.cache_config
    self.observability_config = vllm_config.observability_config
    self.log_stats = log_stats

    parallel_config = vllm_config.parallel_config
    executor_backend = (
        self.vllm_config.parallel_config.distributed_executor_backend)
    uses_data_parallel = parallel_config.data_parallel_size > 1
    self.external_launcher_dp = (uses_data_parallel
                                 and executor_backend == "external_launcher")

    # important: init dp group before init the engine_core
    # In the decoupled engine case this is handled in EngineCoreProc.
    owns_dp_group = (not multiprocess_mode and uses_data_parallel
                     and not self.external_launcher_dp)
    self.dp_group = (parallel_config.stateless_init_dp_group()
                     if owns_dp_group else None)
    self.should_execute_dummy_batch = False

    # Tokenizer (+ ensure liveness if running in another process),
    # unless tokenizer initialization is explicitly skipped.
    self.tokenizer = (None if self.model_config.skip_tokenizer_init else
                      init_tokenizer_from_configs(
                          model_config=vllm_config.model_config))

    # Processor (convert Inputs --> EngineCoreRequests)
    self.processor = Processor(vllm_config=vllm_config,
                               tokenizer=self.tokenizer,
                               mm_registry=mm_registry)

    # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
    self.output_processor = OutputProcessor(self.tokenizer,
                                            log_stats=self.log_stats)
    otlp_endpoint = self.observability_config.otlp_traces_endpoint
    if otlp_endpoint is not None:
        self.output_processor.tracer = init_tracer("vllm.llm_engine",
                                                   otlp_endpoint)

    # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
    self.engine_core = EngineCoreClient.make_client(
        multiprocess_mode=multiprocess_mode,
        asyncio_mode=False,
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=self.log_stats,
    )

    self.logger_manager: Optional[StatLoggerManager] = None  # type: ignore
    if self.log_stats:
        self.logger_manager = StatLoggerManager(
            vllm_config=vllm_config,
            custom_stat_loggers=stat_loggers,
            enable_default_loggers=log_stats,
        )
        self.logger_manager.log_engine_initialized()

    if not multiprocess_mode:
        # for v0 compatibility
        self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

    if self.external_launcher_dp:
        # If we use DP in external launcher mode, we reuse the
        # existing DP group used for data communication.
        self.dp_group = get_dp_group().cpu_group

    # Don't keep the dummy data in memory
    self.reset_mm_cache()
@patch_to(LLMEngine)
def __del__(self):
    """Destroy the engine-owned DP group and release CPU all-reduce memory.

    The DP process group is only destroyed when this engine created it
    itself, i.e. not in external-launcher DP mode where the group is
    shared.
    """
    # `getattr` guards against instances whose __init__ did not get as far
    # as setting `dp_group`. The lookup is kept separate from the
    # condition: in `x := a and b` the walrus binds the whole boolean
    # expression, which would hand `True` (not the process group) to the
    # destroy call below.
    dp_group = getattr(self, "dp_group", None)
    if dp_group and not self.external_launcher_dp:
        stateless_destroy_torch_distributed_process_group(dp_group)

    # Fetch the handle once and reuse it for the check and the cleanup.
    shared_mem = get_cpu_all_reduce_shared_mem()
    if shared_mem is not None:
        shared_mem._cleanup()

View File

@@ -0,0 +1,20 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from vllm_br.executor.ray_distributed_executor import ( # noqa: F401
_init_workers_ray_br)
from . import ray_distributed_executor
__all__ = ["_init_workers_ray_br", "ray_distributed_executor"]

View File

@@ -0,0 +1,75 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from concurrent.futures import Future
from typing import Optional, Union
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.ray_distributed_executor import RayDistributedExecutor
from vllm.v1.outputs import ModelRunnerOutput
class FutureWrapper(Future):
    """Future-shaped handle around a Ray output reference.

    Lets a Ray object reference be returned where ``.execute_model()``
    callers expect a ``Future``.
    """

    def __init__(self, ref):
        """Store the Ray reference that ``result()`` will resolve."""
        super().__init__()
        self.ref = ref

    def result(self, timeout=None):
        """Block on the wrapped reference via ``ray.get`` and return it.

        Raises:
            NotImplementedError: If a ``timeout`` is given.
        """
        if timeout is None:
            return ray.get(self.ref)
        raise NotImplementedError("timeout is not supported")
def execute_model(
    self,
    scheduler_output,
    non_block: bool = False,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
    """Dispatch ``scheduler_output`` to every Ray worker.

    ``execute_model_ray`` is invoked remotely on each worker of each
    pipeline stage; only the references produced by the last stage are
    kept, and the first of those is what gets returned.

    Args:
        scheduler_output: The scheduler output to execute.
        non_block: Currently not consulted by this implementation.

    Returns:
        The resolved output when ``self.max_concurrent_batches == 1``,
        otherwise a ``FutureWrapper`` around the first last-stage reference.
    """
    # TODO: adapt to the `non_block` parameter; blocking vs. non-blocking
    # is currently decided by `max_concurrent_batches` alone.
    assert self.parallel_config.use_ray

    final_stage = len(self.pp_tp_workers) - 1
    output_refs = []
    for stage, stage_workers in enumerate(self.pp_tp_workers):
        stage_refs = [
            worker.execute_model_ray.remote(scheduler_output)
            for worker in stage_workers
        ]
        # Only the last pipeline stage yields the model output.
        if stage == final_stage:
            output_refs.extend(stage_refs)

    # When PP is not used, we block here until the result is available.
    if self.max_concurrent_batches == 1:
        return ray.get(output_refs[0])

    # When PP is used, we return a FutureWrapper immediately so that
    # the scheduler can yield to the next batch.
    return FutureWrapper(output_refs[0])
def execute_model_ray(
        self,
        scheduler_output: SchedulerOutput) -> Optional[ModelRunnerOutput]:
    """Run ``scheduler_output`` on the wrapped worker and return its result."""
    wrapped_worker = self.worker
    return wrapped_worker.execute_model(scheduler_output)
# Install the functions defined above on the vLLM classes: `execute_model`
# replaces the executor method, and `execute_model_ray` is attached to the
# Ray worker wrapper so that `execute_model` can invoke it remotely.
RayDistributedExecutor.execute_model = execute_model # type: ignore[attr-defined]
RayWorkerWrapper.execute_model_ray = execute_model_ray # type: ignore[attr-defined]

View File

@@ -0,0 +1,25 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# TODO(ychun) temp annotation
# @property # type: ignore
# def AttentionSpec_page_size_bytes(self) -> int:
# # For MLA we only store a single latent vector, BR166 uses BB, so it needs to be multiplied by 2
# coef = 1 if (self.use_mla and envs.VLLM_BR_DEVICE_SPC_NUM <= 16) else 2
# return coef * self.block_size * self.num_kv_heads * self.head_size \
# * get_dtype_size(self.dtype)
# AttentionSpec.page_size_bytes = AttentionSpec_page_size_bytes

41
vllm_br/v1/outputs.py Normal file
View File

@@ -0,0 +1,41 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
from fastcore.basics import patch_to
from vllm.v1.outputs import (DraftTokenIds, KVConnectorOutput, LogprobsLists,
LogprobsTensors, ModelRunnerOutput)
@patch_to(ModelRunnerOutput)
def __init__(self,
             req_ids: list[str],
             req_id_to_index: dict[str, int],
             sampled_token_ids: list[list[int]],
             logprobs: Optional[LogprobsLists],
             prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]],
             pooler_output: list[Optional[torch.Tensor]],
             kv_connector_output: Optional[KVConnectorOutput] = None,
             num_nans_in_logits: Optional[dict[str, int]] = None,
             draft_token_ids: Optional["DraftTokenIds"] = None):
    """Initialize a ModelRunnerOutput that also carries draft token ids.

    The first eight arguments are forwarded positionally, in order, to the
    previous initializer (``self._orig___init__``, presumably saved by
    ``patch_to`` when it replaced ``__init__``); ``draft_token_ids`` is
    then stored as an extra attribute.
    """
    forwarded_args = (req_ids, req_id_to_index, sampled_token_ids, logprobs,
                      prompt_logprobs_dict, pooler_output,
                      kv_connector_output, num_nans_in_logits)
    self._orig___init__(*forwarded_args)
    self.draft_token_ids = draft_token_ids

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import ops # noqa: F401

Binary file not shown.

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import logprobs, topk_topp_sampler # noqa: F401

View File

@@ -0,0 +1,40 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
"""Some utilities for logprobs, including logits."""
import torch
from fastcore.basics import patch_to
from vllm.v1.sample.ops import logprobs
@patch_to(logprobs)
def batched_count_greater_than(x: torch.Tensor,
                               values: torch.Tensor) -> torch.Tensor:
    """Count, per row of ``x``, the elements that are >= the row's value.

    Despite the function name, the comparison is inclusive
    (greater than or equal).

    Args:
        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).

    Returns:
        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
    """
    at_least = torch.ge(x, values)
    return at_least.sum(dim=-1)

View File

@@ -0,0 +1,138 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from fastcore.basics import patch_to
from vllm.v1.sample.ops import topk_topp_sampler
def topk_topp_sampler_forward_native(
    self,
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Sample token ids with PyTorch-native top-k / top-p filtering.

    The logits are masked by ``apply_top_k_top_p``, turned into float32
    probabilities and sampled with ``random_sample``.
    The logits tensor may be updated in-place.
    """
    masked_logits = apply_top_k_top_p(logits, k, p)
    probabilities = masked_logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probabilities, generators)
def apply_top_k_only(
logits: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
"""
Apply top-k mask to the logits.
This implementation doesn't involve sorting the entire vocab.
The logits tensor may be updated in-place.
"""
no_top_k_mask = k == logits.shape[1]
# Set non-top-k rows to 1 so that we can gather.
k = k.masked_fill(no_top_k_mask, 1)
max_top_k = k.max()
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
k_index = k.sub_(1).unsqueeze(1)
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
# Handle non-topk rows.
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
logits.masked_fill_(logits < top_k_mask, -float("inf"))
return logits
# NOTE: this use of scatter is not supported on BR; it needs to be fixed.
@patch_to(topk_topp_sampler)
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    Without top-p the logits are either returned unchanged (no top-k) or
    masked by ``apply_top_k_only``. With top-p the logits are sorted, which
    can be slow for large batches, masked in sorted order and scattered
    back to their original positions into a new tensor.
    """
    if p is None:
        # Avoid sorting vocab for the top-k only case.
        return logits if k is None else apply_top_k_only(logits, k)

    # Ascending sort: the largest logits sit at the end of each row.
    sorted_logits, sort_order = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Position of each row's k-th largest value in the sorted row.
        kth_position = sorted_logits.size(1) - k.to(torch.long)  # shape: B
        kth_value = sorted_logits.gather(1, kth_position.unsqueeze(dim=1))
        sorted_logits.masked_fill_(sorted_logits < kth_value, -float("inf"))

    # Apply top-p: drop the low-probability prefix whose cumulative mass
    # stays within 1 - p.
    cumulative = sorted_logits.softmax(dim=-1)
    cumulative = torch.cumsum(cumulative, dim=-1, out=cumulative)
    below_cutoff = cumulative <= 1 - p.unsqueeze(dim=1)
    # Always keep at least the most likely token.
    below_cutoff[:, -1] = False
    sorted_logits.masked_fill_(below_cutoff, -float("inf"))

    # Undo the sort: write each value back to its original position, using
    # an in-place scatter on a fresh clone.
    restored = sorted_logits.clone()
    restored = restored.scatter_(dim=-1, index=sort_order, src=sorted_logits)
    return restored
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Draw one index per row by dividing by exponential noise and taking argmax.

    Used instead of torch.multinomial because torch.multinomial causes a
    CPU-GPU synchronization.

    Args:
        probs: Probabilities indexed as [row, vocab]. Overwritten in place
            (divided by the noise).
        generators: Maps a row index to the generator whose noise replaces
            the default noise of that row.

    Returns:
        A flattened tensor with the chosen index of every row.
    """
    noise = torch.empty_like(probs)
    num_rows = probs.shape[0]
    # NOTE(woosuk): Requests without their own seed are the common case, so
    # fill the whole batch at once first and then overwrite only the seeded
    # rows. The batch-wide fill is skipped when every row has a generator.
    if len(generators) != num_rows:
        noise.exponential_()
    # TODO(woosuk): This can be slow because we handle each request
    # one by one. Optimize this.
    for row, generator in generators.items():
        noise[row].exponential_(generator=generator)
    probs.div_(noise)
    return probs.argmax(dim=-1).view(-1)
# vllm.v1.sample.ops.topk_topp_sampler.TopKTopPSampler.forward_native = topk_topp_sampler_forward_native

View File

@@ -0,0 +1,17 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import eagle # noqa: F401

View File

@@ -0,0 +1,265 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import torch
from fastcore.basics import patch_to
import vllm_br.envs as biren_envs
from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_br.v1.worker.model_runner import SUPACommonAttentionMetadata
logger = init_logger(__name__)
PADDING_SLOT_ID = -1
def wrapper_EagleProposer_init(fn):
    """Wrap ``EagleProposer.__init__`` to adjust the draft model config.

    After the wrapped initializer has run, the wrapper sets ``weight_type``,
    ``use_ds_mla`` and ``use_ds_mla_sparse`` on ``self.draft_model_config``.

    Args:
        fn: The initializer to wrap.

    Returns:
        The wrapping initializer (it returns None, like ``__init__``).
    """

    # FIXME: temporary fix for enabling MLA in EagleProposer
    @wraps(fn)
    def patched_init(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        draft_config = self.draft_model_config
        draft_config.weight_type = biren_envs.VLLM_BR_WEIGHT_TYPE
        draft_config.use_ds_mla = True
        # The sparse variant is enabled when the HF config has `index_topk`.
        draft_config.use_ds_mla_sparse = hasattr(draft_config.hf_config,
                                                 "index_topk")

    return patched_init


EagleProposer.__init__ = wrapper_EagleProposer_init(
    EagleProposer.__init__)  # noqa: E501
@patch_to(EagleProposer)
def prepare_inputs(
    self,
    common_attn_metadata: SUPACommonAttentionMetadata,
    sampled_token_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> tuple[SUPACommonAttentionMetadata, torch.Tensor]:
    """Build speculator inputs with the rejected draft tokens removed.

    With per-request query lengths [q1, q2, q3], sequence lengths
    [s1, s2, s3] and rejected-token counts [n1, n2, n3], the returned
    metadata has:

        query_start_loc{_cpu}: [0, q1 - n1, q1 + q2 - n1 - n2, ...]
        seq_lens{_cpu}:        [s1 - n1, s2 - n2, s3 - n3]

    and the returned token indices are

        [0, ..., q1 - n1 - 1,
         q1, ..., q1 + q2 - n2 - 1,
         q1 + q2, ..., q1 + q2 + q3 - n3 - 1]

    Args:
        common_attn_metadata: Attention metadata of the current batch.
        sampled_token_ids: The sampled token ids of every request.
        num_draft_tokens: Number of draft tokens of every request.

    Returns:
        The updated attention metadata and the indices (into the original
        flat token layout) of the tokens to feed to the speculator.
    """
    # A request without draft tokens has nothing to reject; otherwise the
    # count is (drafts + 1) minus the number of tokens actually sampled.
    rejected_counts = []
    for req_idx, n_draft in enumerate(num_draft_tokens):
        if n_draft > 0:
            rejected_counts.append(n_draft + 1 -
                                   len(sampled_token_ids[req_idx]))
        else:
            rejected_counts.append(0)
    num_rejected_tokens = torch.tensor(rejected_counts, dtype=torch.int32)

    device = common_attn_metadata.query_start_loc.device
    old_query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    new_seq_lens_cpu = (common_attn_metadata.seq_lens_cpu -
                        num_rejected_tokens)

    # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
    old_query_lens = (old_query_start_loc_cpu[1:] -
                      old_query_start_loc_cpu[:-1])
    # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
    kept_per_req = old_query_lens - num_rejected_tokens
    kept_per_req_np = kept_per_req.numpy()

    # Prefix sum of the kept counts, written through a NumPy view straight
    # into the tensor that becomes the new query_start_loc.
    new_query_start_loc_cpu = torch.zeros(old_query_start_loc_cpu.shape,
                                          dtype=torch.int32,
                                          pin_memory=is_pin_memory_available())
    new_query_start_loc_np = new_query_start_loc_cpu.numpy()
    np.cumsum(kept_per_req_np, out=new_query_start_loc_np[1:])
    total_num_tokens = new_query_start_loc_np[-1]

    # Example with kept_per_req_np = [2, 4, 3], i.e. new starts [0, 2, 6, 9]:
    #   new starts, one entry per kept token: [0, 0, 2, 2, 2, 2, 6, 6, 6]
    new_starts_per_token = np.repeat(new_query_start_loc_np[:-1],
                                     kept_per_req_np)
    #   offset of each kept token inside its request:
    #   [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1, 0, 1, 2, 3, 0, 1, 2]
    token_offsets = (self.token_arange_np[:total_num_tokens] -
                     new_starts_per_token)
    #   old starts, one entry per kept token:
    #   [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
    old_starts_per_token = np.repeat(old_query_start_loc_cpu[:-1].numpy(),
                                     kept_per_req_np)
    #   old start + offset = position of the kept token in the old layout.
    token_indices_np = token_offsets + old_starts_per_token
    token_indices = torch.from_numpy(token_indices_np).to(device,
                                                          non_blocking=True)

    # NOTE(review): no `seq_start_loc` is passed here (an earlier computation
    # of it is disabled) - confirm SUPACommonAttentionMetadata does not need
    # it on this path.
    spec_common_attn_metadata = SUPACommonAttentionMetadata(
        query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
        seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
        query_start_loc_cpu=new_query_start_loc_cpu,
        seq_lens_cpu=new_seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        # Taken from the query lengths *before* rejection.
        max_query_len=old_query_lens.max().item(),
        max_seq_len=new_seq_lens_cpu.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping[token_indices],
        causal=True,
    )
    return spec_common_attn_metadata, token_indices
@patch_to(EagleProposer)
def prepare_inputs_padded(self,
                          common_attn_metadata: SUPACommonAttentionMetadata,
                          spec_decode_metadata: SpecDecodeMetadata,
                          valid_sampled_tokens_count: torch.Tensor) -> \
                tuple[SUPACommonAttentionMetadata, torch.Tensor, torch.Tensor]:
    """Build speculator inputs without removing the rejected tokens.

    All tokens are kept as inputs to the speculator; the rejected ones act
    as padding and are filtered out later through `token_indices_to_sample`.
    No blocking CPU operations should be introduced in this function.

    Args:
        common_attn_metadata: Attention metadata of the current batch.
        spec_decode_metadata: Provides `cu_num_draft_tokens`, the cumulative
            number of draft tokens per request.
        valid_sampled_tokens_count: Number of valid sampled tokens per
            request.

    Returns:
        A tuple of (attention metadata for the speculator, indices of all
        input tokens, per-request index of the token to sample from).
    """
    # Cumulative counts -> per-request counts: the first entry is already a
    # count, the remaining ones are differences of neighbours.
    cu_num_draft_tokens = spec_decode_metadata.cu_num_draft_tokens
    num_draft_tokens_gpu = torch.cat([
        cu_num_draft_tokens[0:1],
        cu_num_draft_tokens[1:] - cu_num_draft_tokens[:-1]
    ])

    # Requests without draft tokens have nothing to reject.
    num_rejected_tokens_gpu = torch.where(
        num_draft_tokens_gpu > 0,
        num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
        torch.zeros_like(num_draft_tokens_gpu))

    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    new_query_len_per_req = (query_start_loc_cpu[1:] -
                             query_start_loc_cpu[:-1])

    total_num_tokens = query_start_loc_cpu[-1].item()
    token_indices = self.arange[:total_num_tokens]

    # Exclusive prefix sum of the sequence lengths: [0, s1, s1 + s2, ...].
    # Read the host-side copy (`seq_lens_cpu`) rather than calling `.cpu()`
    # on the device tensor: the latter is a blocking device-to-host copy,
    # which this function must not introduce.
    seq_start_loc = torch.from_numpy(
        np.insert(np.add.accumulate(common_attn_metadata.seq_lens_cpu.numpy()),
                  0, 0)).to(common_attn_metadata.query_start_loc,
                            non_blocking=True)

    spec_common_attn_metadata = SUPACommonAttentionMetadata(
        query_start_loc=common_attn_metadata.query_start_loc,
        seq_lens=common_attn_metadata.seq_lens,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        max_query_len=new_query_len_per_req.max().item(),
        max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping[token_indices.long()],
        causal=True,
        seq_start_loc=seq_start_loc)

    # Last token of each request, moved back by the number of rejected tokens.
    token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
        - num_rejected_tokens_gpu

    return spec_common_attn_metadata, token_indices, token_indices_to_sample
def wrapper_EagleProposer_propose(fn):
    """Wrap ``EagleProposer.propose`` so ``last_token_indices`` is int64.

    If no indices are supplied, the wrapper uses the last token of each
    request (``query_start_loc[1:] - 1``). In both cases the indices are
    converted with ``.long()`` before the wrapped function is called; every
    other argument is forwarded positionally and unchanged.

    Args:
        fn: The ``propose`` implementation to wrap.

    Returns:
        The wrapping function, which returns whatever ``fn`` returns.
    """

    @wraps(fn)
    def propose_with_long_indices(
        self,
        target_token_ids: torch.Tensor,  # [num_tokens]
        target_positions: torch.Tensor,  # [num_tokens]
        target_hidden_states: torch.Tensor,  # [num_tokens, hidden_size]
        next_token_ids: torch.Tensor,  # [batch_size]
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ):
        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
        return fn(self, target_token_ids, target_positions,
                  target_hidden_states, next_token_ids,
                  last_token_indices.long(), common_attn_metadata,
                  sampling_metadata, mm_embeds)

    return propose_with_long_indices


EagleProposer.propose = wrapper_EagleProposer_propose(
    EagleProposer.propose)  # noqa: E501

View File

@@ -0,0 +1,15 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,49 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import copy
from typing import TYPE_CHECKING
from vllm.config import VllmConfig
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm_br.forward_context import set_forward_context
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
                            vllm_config: VllmConfig) -> ModelRunnerOutput:
    """Run the KV-connector send/recv step when there is no model work to do.

    Args:
        scheduler_output: Scheduler output forwarded to
            ``_get_kv_connector_output``.
        vllm_config: Config used to set up the forward context.

    Returns:
        ``EMPTY_MODEL_RUNNER_OUTPUT`` itself when nothing finished sending
        or receiving; otherwise a shallow copy of it carrying the connector
        output.
    """
    # KV send/recv even if no work to do.
    with set_forward_context(
            None,
            vllm_config), KVConnectorModelRunnerMixin._get_kv_connector_output(
                scheduler_output, wait_for_save=False) as kv_connector_output:
        pass

    if (not kv_connector_output.finished_sending
            and not kv_connector_output.finished_recving):
        return EMPTY_MODEL_RUNNER_OUTPUT

    # Copy before attaching the result so the shared empty output is never
    # mutated.
    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    output.kv_connector_output = kv_connector_output
    return output


# The function takes no `self`, so it must be installed as a staticmethod:
# a plain function stored on the class would receive the instance as
# `scheduler_output` whenever it is called through an instance.
KVConnectorModelRunnerMixin.kv_connector_no_forward = staticmethod(
    kv_connector_no_forward)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,413 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import threading
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.pynccl_allocator import (
set_graph_pool_id)
from vllm.forward_context import (create_forward_context, get_forward_context,
override_forward_context)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import has_deep_gemm
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
from vllm_br.compilation.supa_graph import SUPAGraphWrapper
from vllm_br.config.compilation import SUPAGraphMode
logger = init_logger(__name__)
@dataclass
class UbatchMetadata:
    """Model inputs and execution context of a single micro-batch (ubatch)."""
    # Context object for this ubatch; the ubatch thread runs the model inside
    # `with context:`.
    context: UBatchContext
    # The model inputs below are the slices belonging to this ubatch.
    input_ids: torch.Tensor
    positions: torch.Tensor
    inputs_embeds: Optional[torch.Tensor]
    intermediate_tensors: Optional[IntermediateTensors]
    # Number of tokens in this ubatch.
    num_tokens: int
@dataclass
class SUPAGraphMetaData:
    """A captured SUPA graph together with what is needed to replay it."""
    # The captured graph object; replayed with `supagraph.replay()`.
    supagraph: torch.supa.SUPAGraph
    # Metadata of all ubatches the graph was captured with (the capture code
    # stores the whole list, one entry per ubatch).
    ubatch_metadata: list[UbatchMetadata]
    # Model output produced during capture; returned again after each replay.
    outputs: Optional[Any] = None
class SMControlContextManager:
    """Context manager that splits SMs between communication and compute.

    On entry the communication setter is called with ``comm_sms`` and the
    compute setter with ``total_sms - comm_sms``. On exit both setters are
    called with ``total_sms``, i.e. every available SM.
    """

    def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None],
                 set_compute_sms: Callable[[int], None]):
        """Read the device's SM count and store the requested split.

        Args:
            comm_sms (int): The number of SMs to allocate for communication.
                (The remainder will be used for computation.)
            set_comm_sms (Callable[[int], None]):
                A function that sets the number of SMs for communication.
            set_compute_sms (Callable[[int], None]):
                A function that sets the number of SMs for computation.
        """
        assert current_platform.is_supa(), \
            "SM control is currently only supported on SUPA"

        device_index = torch.supa.current_device()
        sm_count = torch.supa.get_device_properties(
            device_index).multi_processor_count

        # At least one SM has to remain for computation.
        assert comm_sms < sm_count

        self.total_sms = sm_count
        self.comm_sms = comm_sms
        self.compute_sms = sm_count - comm_sms
        self.set_comm_sms = set_comm_sms
        self.set_compute_sms = set_compute_sms

    def __enter__(self):
        """Apply the communication / compute split."""
        self.set_comm_sms(self.comm_sms)
        self.set_compute_sms(self.compute_sms)

    def __exit__(self, exc_type, exc_value, traceback):
        """Give all SMs back to both communication and compute."""
        everything = self.total_sms
        self.set_comm_sms(everything)
        self.set_compute_sms(everything)
class UBatchWrapper:
    """Callable wrapper that runs a model as two micro-batches (ubatches).

    When the forward context carries ubatch slices, each ubatch is executed
    on its own thread, optionally captured into / replayed from a SUPA
    graph. Without ubatch slices the call is forwarded to the wrapped
    runnable (or to a ``SUPAGraphWrapper`` around it).
    """

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_mode: SUPAGraphMode, device: torch.supa.device):
        """Store the runnable and set up streams, barrier and graph cache.

        Args:
            runnable: The callable to wrap.
            vllm_config: The vLLM configuration.
            runtime_mode: Graph mode; for any mode other than ``NONE`` a
                ``SUPAGraphWrapper`` is created for the non-ubatched path.
            device: Device used for the communication stream and by the
                capture threads.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.comm_stream = torch.supa.Stream(device=device)
        # Two ubatch threads plus the main thread
        self.ready_barrier = threading.Barrier(3)

        # Captured ubatched graphs, keyed by total number of tokens.
        self.supagraphs: dict[int, SUPAGraphMetaData] = {}

        self.supagraph_wrapper = None
        self.graph_pool = None
        if runtime_mode is not SUPAGraphMode.NONE:
            self.supagraph_wrapper = SUPAGraphWrapper(
                runnable, vllm_config, runtime_mode=runtime_mode)
            self.graph_pool = current_platform.get_global_graph_pool()

        self.sm_control = self._create_sm_control_context(vllm_config)
        self.device = device

    @staticmethod
    def _create_sm_control_context(vllm_config: VllmConfig):
        """Build the SM-control context manager used around ubatched runs.

        The setters default to no-ops and are only replaced when the
        corresponding backend is available and ``comm_sms`` is positive.
        """
        comm_sms = envs.VLLM_DBO_COMM_SMS

        set_comm_sms = lambda sms: None
        if vllm_config.parallel_config.enable_expert_parallel:
            # Currently only DeepEP highthroughput supports SM control so this
            # only affects that case.
            all2all_manager = get_ep_group(
            ).device_communicator.all2all_manager

            if all2all_manager.max_sms_used() is not None:
                comm_sms = min(comm_sms, all2all_manager.max_sms_used())

            if comm_sms > 0:
                set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)

        # TODO(lucas): support other kernels besides DeepGEMM
        set_compute_sms = lambda sms: None
        if has_deep_gemm() and comm_sms > 0:
            import deep_gemm as dg
            set_compute_sms = lambda sms: dg.set_num_sms(sms)

        return SMControlContextManager(comm_sms=comm_sms,
                                       set_comm_sms=set_comm_sms,
                                       set_compute_sms=set_compute_sms)

    def __getattr__(self, key: str):
        """Fall back to the wrapped runnable for unknown attributes."""
        # If `runnable` itself is missing (e.g. the instance was created
        # without running __init__), reading self.runnable below would
        # re-enter __getattr__ and recurse without end.
        if key == "runnable":
            raise AttributeError(key)
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} not exists in the runnable of "
                             f"supagraph wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        """Return the original runnable."""
        # in case we need to access the original runnable.
        return self.runnable

    def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """
        Capture a supagraph for a microbatched run.

        The logic here is somewhat complicated because we need to make sure that
        each of the ubatch threads initialize the supa context before we start
        the graph capture.

        The flow is as follows:
        1. The main thread starts up each ubatch thread. Each thread will
        initialize its supa context (torch.supa.current_blas_handle())
        before going to sleep upon entering the ubatch_context.

        2. The main thread starts the graph capture and wakes up the first
        ubatch thread.

        3. Each ubatch thread runs the model to completion and returns the
        completed output tensors back to the main thread.

        4. The main thread stores the captured supagraph along with its metadata
        and returns
        """

        @torch.inference_mode()
        def _capture_ubatch_thread(results, ubatch_metadata):
            torch.supa.set_device(self.device)
            ubatch_context = ubatch_metadata.context
            # Touch the BLAS handle on both streams before capture starts.
            with torch.supa.stream(ubatch_context.compute_stream):
                _ = torch.supa.current_blas_handle()
            with torch.supa.stream(ubatch_context.comm_stream):
                _ = torch.supa.current_blas_handle()
            with ubatch_context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )

            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []
        compute_stream = ubatch_metadata[0].context.compute_stream
        # Cache key: total tokens across the two ubatches.
        num_tokens = ubatch_metadata[0].num_tokens + \
            ubatch_metadata[1].num_tokens

        # Ubatches will manually manage the forward context, so we override
        # it to None here so we can have it restored correctly later
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(target=_capture_ubatch_thread,
                                          args=(
                                              results,
                                              metadata,
                                          ))
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready

            # Capture the supagraph
            supagraph_metadata = \
                SUPAGraphMetaData(
                    supagraph=torch.supa.SUPAGraph(),
                    ubatch_metadata=ubatch_metadata,
                )
            if self.graph_pool is not None:
                set_graph_pool_id(self.graph_pool)
            else:
                set_graph_pool_id(current_platform.graph_pool_handle())
            with torch.supa.graph(supagraph_metadata.supagraph,
                                  stream=compute_stream,
                                  pool=self.graph_pool):
                # Wake the first ubatch thread, then wait for both to finish.
                ubatch_metadata[0].context.cpu_wait_event.set()
                for thread in ubatch_threads:
                    thread.join()
                # Order the outputs by ubatch id before concatenating.
                sorted_results = [value for position, value in sorted(results)]
                result = torch.cat(sorted_results, dim=0)
                supagraph_metadata.outputs = result
            self.supagraphs[num_tokens] = supagraph_metadata
        return supagraph_metadata.outputs

    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
        """Run every ubatch on its own thread, without graph capture.

        Returns:
            The ubatch outputs concatenated along dim 0, ordered by ubatch id.
        """

        @torch.inference_mode()
        def _ubatch_thread(results, model, ubatch_metadata):
            with ubatch_metadata.context:
                model_output = model(
                    input_ids=ubatch_metadata.input_ids,
                    positions=ubatch_metadata.positions,
                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
                    inputs_embeds=ubatch_metadata.inputs_embeds,
                )
            results.append((ubatch_metadata.context.id, model_output))

        results: list[tuple[int, torch.Tensor]] = []

        # Ubatch threads will manually manage the forward context, so we
        # override it to None here so we can have it restored correctly
        # after both threads have finished
        with override_forward_context(None):
            ubatch_threads = []
            for metadata in ubatch_metadata:
                thread = threading.Thread(target=_ubatch_thread,
                                          args=(
                                              results,
                                              model,
                                              metadata,
                                          ))
                ubatch_threads.append(thread)
                thread.start()
            self.ready_barrier.wait()  # Wait for both threads to be ready
            ubatch_metadata[0].context.cpu_wait_event.set()
            for thread in ubatch_threads:
                thread.join()
        sorted_results = [value for position, value in sorted(results)]
        result = torch.cat(sorted_results, dim=0)
        return result

    def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
                              positions, inputs_embeds, intermediate_tensors,
                              compute_stream, dp_metadata, batch_descriptor,
                              supagraph_runtime_mode) -> list[UbatchMetadata]:
        """Create the per-ubatch forward contexts and sliced model inputs."""

        # Create one forward context per ubatch
        forward_contexts = []
        for i, _ in enumerate(ubatch_slices):
            forward_contexts.append(
                create_forward_context(
                    attn_metadata[i] if attn_metadata is not None else None,
                    self.vllm_config,
                    dp_metadata=dp_metadata,
                    batch_descriptor=batch_descriptor,
                    supagraph_runtime_mode=supagraph_runtime_mode))

        ubatch_ctxs = make_ubatch_contexts(
            num_micro_batches=len(ubatch_slices),
            comm_stream=self.comm_stream,
            compute_stream=compute_stream,
            forward_contexts=forward_contexts,
            ready_barrier=self.ready_barrier)

        ubatch_metadata: list[UbatchMetadata] = []
        for i, ubatch_slice in enumerate(ubatch_slices):
            sliced_input_ids, sliced_positions, sliced_inputs_embeds, \
            sliced_intermediate_tensors = \
                self._slice_model_inputs(
                    ubatch_slice.token_slice, input_ids, positions,
                    inputs_embeds, intermediate_tensors)
            ubatch_metadata.append(
                UbatchMetadata(
                    context=ubatch_ctxs[i],
                    input_ids=sliced_input_ids,
                    positions=sliced_positions,
                    inputs_embeds=sliced_inputs_embeds,
                    intermediate_tensors=sliced_intermediate_tensors,
                    num_tokens=ubatch_slice.token_slice.stop -
                    ubatch_slice.token_slice.start))

        return ubatch_metadata

    def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions,
                            inputs_embeds, intermediate_tensors):
        """Return the part of each model input selected by ``tokens_slice``.

        Returns:
            A tuple (input_ids, positions, inputs_embeds,
            intermediate_tensors); the last two are None when not provided.
        """
        sliced_input_ids = input_ids[tokens_slice]
        # if we are using mrope. Mrope adds an additional dimension to the
        # positions tensor
        if positions.ndim == 2:
            sliced_positions = positions[:, tokens_slice]
        else:
            sliced_positions = positions[tokens_slice]
        # Compare against None explicitly: the truth value of a tensor with
        # more than one element is ambiguous and raises a RuntimeError.
        sliced_inputs_embeds = inputs_embeds[
            tokens_slice] if inputs_embeds is not None else None
        sliced_intermediate_tensors = intermediate_tensors[
            tokens_slice] if intermediate_tensors else None
        return (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
                sliced_intermediate_tensors)

    def __call__(self, *args, **kwargs):
        """Run the model, ubatched if the forward context asks for it.

        The ubatched path reads ``input_ids``, ``positions``,
        ``intermediate_tensors`` and ``inputs_embeds`` from ``kwargs``, so
        they must be passed by keyword.
        """
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        ubatch_slices = forward_context.ubatch_slices
        # NOTE(review): the mode is read from `cudagraph_runtime_mode` here but
        # passed as `supagraph_runtime_mode=` in _make_ubatch_metadata - confirm
        # both names are valid for the forward context in use.
        supagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        # If there's no ubatching, just run the runnable object
        if ubatch_slices is None:

            # This is to account for the case where ubatching was aborted.
            # When we capture full graphs we only capture one graph per shape,
            # meaning that if we have a ubatched supagraph for the current
            # num_tokens, we don't have a non-ubatched one. Without this
            # check, the supagraph wrapper will try to capture a supagraph
            # for this shape during a normal run.
            if supagraph_runtime_mode is SUPAGraphMode.FULL:
                assert batch_descriptor is not None
                if batch_descriptor.num_tokens in self.supagraphs:
                    supagraph_runtime_mode = SUPAGraphMode.NONE

            if supagraph_runtime_mode in (SUPAGraphMode.NONE,
                                          SUPAGraphMode.PIECEWISE):
                return self.runnable(*args, **kwargs)
            else:
                assert self.supagraph_wrapper is not None
                return self.supagraph_wrapper(*args, **kwargs)

        attn_metadata = forward_context.attn_metadata
        # NOTE(review): this key is twice the size of the first ubatch, while
        # _capture_ubatches stores graphs under the sum of both ubatch sizes -
        # confirm the two ubatches always have the same number of tokens.
        num_tokens = (ubatch_slices[0].token_slice.stop -
                      ubatch_slices[0].token_slice.start) * 2
        input_ids = kwargs['input_ids']
        positions = kwargs['positions']
        intermediate_tensors = kwargs['intermediate_tensors']
        inputs_embeds = kwargs['inputs_embeds']
        compute_stream = torch.supa.current_stream()

        dp_metadata = forward_context.dp_metadata

        # We shouldn't be here unless we are running with multiple DP ranks
        assert dp_metadata is not None

        if num_tokens not in self.supagraphs \
            and supagraph_runtime_mode is SUPAGraphMode.FULL:
            # First time this size is seen in FULL mode: capture a graph.
            ubatch_metadata = self._make_ubatch_metadata(
                ubatch_slices=ubatch_slices,
                attn_metadata=attn_metadata,
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                compute_stream=compute_stream,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                supagraph_runtime_mode=SUPAGraphMode.NONE)
            with self.sm_control:
                return self._capture_ubatches(ubatch_metadata, self.model)
        elif num_tokens in self.supagraphs \
            and supagraph_runtime_mode is SUPAGraphMode.FULL:
            # A graph for this size already exists: replay it.
            supagraph_metadata = self.supagraphs[num_tokens]
            supagraph_metadata.supagraph.replay()
            return supagraph_metadata.outputs
        else:
            # Any other mode: run the ubatches eagerly on their threads.
            ubatch_metadata = self._make_ubatch_metadata(
                ubatch_slices=ubatch_slices,
                attn_metadata=attn_metadata,
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                compute_stream=compute_stream,
                dp_metadata=dp_metadata,
                batch_descriptor=batch_descriptor,
                supagraph_runtime_mode=SUPAGraphMode.NONE)
            with self.sm_control:
                return self._run_ubatches(ubatch_metadata, self.model)

View File

@@ -0,0 +1,155 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import logger
from vllm_br.config.compilation import SUPAGraphMode
from vllm_br.forward_context import BatchDescriptor
# Step between consecutive default capture sizes above 4.
_BATCH_SIZE_ALIGNMENT = 8
# Default list of sizes to capture: 1, 2, 4, then every multiple of 8 from
# 8 up to 264 (8 * 33).
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 34)
]
class SupagraphDispatcher:
"""
Runtime supagraph dispatcher to dispatch keys for multiple set of
supagraphs.
The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one
for FULL supagraph runtime mode. The keys are initialized depending on
attention support and what supagraph mode is set in CompilationConfig. The
keys stored in dispatcher are the only source of truth for valid
supagraphs that can be dispatched at runtime.
At runtime, the dispatch method generates the runtime supagraph mode (FULL,
PIECEWISE, or NONE for no supagraph) and the valid key (batch descriptor)
based on the input key. After dispatching (communicate via forward context),
the supagraph wrappers will trust the dispatch key to do either capturing
or replaying (if mode matched), or pass through to the underlying runnable
without supagraph (if mode no match or mode is NONE).
"""
def __init__(self, vllm_config: VllmConfig):
    """Choose the capture-size list and create the empty key sets.

    Args:
        vllm_config: The vLLM configuration; its compilation config
            supplies ``max_capture_size`` and ``cudagraph_capture_sizes``.
    """
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config

    # The built-in size list is used when the configured maximum is above
    # 256 or is 0; otherwise the configured capture sizes are used.
    max_capture_size = self.compilation_config.max_capture_size
    self.use_default_list = max_capture_size > 256 or max_capture_size == 0
    if self.use_default_list:
        self.capture_list = _BATCH_SIZES_TO_CAPTURE
    else:
        self.capture_list = self.compilation_config.cudagraph_capture_sizes

    # TODO(liming): Remove this hard code once we support piecewise
    self.supagraph_mode = SUPAGraphMode.FULL

    # Valid supagraph dispatching keys, one (initially empty) set per mode.
    self.supagraph_keys: dict[SUPAGraphMode, set[BatchDescriptor]] = {
        mode: set()
        for mode in (SUPAGraphMode.PIECEWISE, SUPAGraphMode.FULL,
                     SUPAGraphMode.FULL_DECODE_ONLY)
    }

    # Piecewise supagraphs are only valid with piecewise compilation whose
    # splitting ops contain attention.
    assert (not self.supagraph_mode.requires_piecewise_compilation()
            or (self.compilation_config.level == CompilationLevel.PIECEWISE
                and self.compilation_config.splitting_ops_contain_attention())
            ), (
                "Compilation level should be CompilationLevel.PIECEWISE when "
                "supagraph_mode piecewise supagraphs is used, "
                f"supagraph_mode={self.supagraph_mode}, "
                f"compilation_level={self.compilation_config.level}, "
                f"splitting_ops={self.compilation_config.splitting_ops}")

    self.keys_initialized = False
def add_supagraph_key(self, runtime_mode: SUPAGraphMode,
                      batch_descriptor: BatchDescriptor):
    """Register ``batch_descriptor`` as a valid key for ``runtime_mode``.

    Args:
        runtime_mode: One of PIECEWISE, FULL_DECODE_ONLY or FULL.
        batch_descriptor: The batch shape that may be dispatched to a
            supagraph in that mode.
    """
    assert runtime_mode in (
        SUPAGraphMode.PIECEWISE,
        SUPAGraphMode.FULL_DECODE_ONLY,
        SUPAGraphMode.FULL,
    ), f"Invalid supagraph runtime mode: {runtime_mode}"
    keys_for_mode = self.supagraph_keys[runtime_mode]
    keys_for_mode.add(batch_descriptor)
def initialize_supagraph_keys(self, supagraph_mode: SUPAGraphMode,
                              uniform_decode_query_len: int):
    """Populate the dispatch keys for ``supagraph_mode``.

    This should be called only after the attention backend is initialized.
    Every key that could be valid is created here; there is no guarantee
    that each of them will actually be captured or replayed later.

    Args:
        supagraph_mode: Runtime mode to create keys for. Only FULL and
            FULL_DECODE_ONLY produce keys; any other mode just marks the
            dispatcher as initialized.
        uniform_decode_query_len: Number of query tokens per request in a
            uniform decode batch; it bounds the usable capture sizes from
            below and, multiplied by ``max_num_seqs``, from above.
    """
    decode_graph_modes = (SUPAGraphMode.FULL,
                          SUPAGraphMode.FULL_DECODE_ONLY)
    if supagraph_mode in decode_graph_modes:
        scheduler_config = self.vllm_config.scheduler_config
        max_num_tokens = (uniform_decode_query_len *
                          scheduler_config.max_num_seqs)
        # Both modes register the same uniform-decode keys: one per capture
        # size that a full decode batch can actually reach.
        for size in self.capture_list:
            if uniform_decode_query_len <= size <= max_num_tokens:
                self.add_supagraph_key(
                    supagraph_mode,
                    BatchDescriptor(num_tokens=size, uniform_decode=True))
    self.keys_initialized = True
def dispatch(
    self, batch_descriptor: BatchDescriptor
) -> tuple[SUPAGraphMode, Optional[BatchDescriptor]]:
    """Choose the supagraph mode and key to use for a batch.

    The returned descriptor may differ from the input one: a uniform batch
    can be served by a graph registered for the more general non-uniform
    descriptor.

    Args:
        batch_descriptor: Descriptor of the batch about to be executed.

    Returns:
        ``(mode, key)``. ``(SUPAGraphMode.NONE, None)`` means no supagraph
        should be used for this batch.
    """
    if not self.keys_initialized:
        # Nothing has been registered yet, so skip dispatching entirely.
        logger.warning_once("supagraph dispatching keys are not "
                            "initialized. No supagraph will be used.")
        return SUPAGraphMode.NONE, None

    # Exact matches first, decode-only graphs taking priority over full.
    for mode in (SUPAGraphMode.FULL_DECODE_ONLY, SUPAGraphMode.FULL):
        if batch_descriptor in self.supagraph_keys[mode]:
            return mode, batch_descriptor

    # Otherwise try the relaxed (non-uniform) form of the key against the
    # full graphs. Piecewise graphs are not consulted here.
    relaxed_key = batch_descriptor.non_uniform
    if relaxed_key in self.supagraph_keys[SUPAGraphMode.FULL]:
        return SUPAGraphMode.FULL, relaxed_key

    # No registered graph fits this batch.
    return SUPAGraphMode.NONE, None

View File

@@ -0,0 +1,195 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import threading
import torch
from vllm import forward_context
from vllm.forward_context import ForwardContext
from vllm.utils import current_stream
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
class SUPAUBatchContext:
    """
    Context manager for micro-batching synchronization using threading events.

    CPU-side hand-off between micro-batch threads is done with a pair of
    ``threading.Event`` objects (one to wait on, one to signal), and
    ordering between the compute and communication streams is done with a
    pair of device events.
    """

    def __init__(self,
                 id: int,
                 comm_stream: torch.supa.Stream,
                 compute_stream: torch.supa.Stream,
                 forward_context: ForwardContext,
                 ready_barrier: threading.Barrier,
                 cpu_wait_event: threading.Event,
                 cpu_signal_event: threading.Event,
                 gpu_comm_done_event: torch.supa.Event,
                 gpu_compute_done_event: torch.supa.Event,
                 schedule: str = "default"):
        """Store the streams, events and forward context for one micro-batch.

        Args:
            id: Index of this micro-batch; used as the slot in the shared
                context registry.
            comm_stream: Stream used for communication work.
            compute_stream: Stream used for compute work (the initial
                current stream).
            forward_context: Forward context restored whenever this
                micro-batch regains control.
            ready_barrier: Barrier waited on when entering the context.
            cpu_wait_event: Event this context blocks on before running.
            cpu_signal_event: Event this context sets when it yields or
                exits.
            gpu_comm_done_event: Device event recorded on ``comm_stream``.
            gpu_compute_done_event: Device event recorded on
                ``compute_stream``.
            schedule: Schedule name; stored but not interpreted here.
        """
        self.id = id
        self.comm_stream = comm_stream
        self.compute_stream = compute_stream
        self.forward_context = forward_context
        self.ready_barrier = ready_barrier
        self.cpu_wait_event = cpu_wait_event
        self.cpu_signal_event = cpu_signal_event
        self.current_stream = compute_stream
        self.gpu_comm_done_event = gpu_comm_done_event
        self.gpu_compute_done_event = gpu_compute_done_event
        self.schedule = schedule
        self.recv_hook = None

    @staticmethod
    def _registries():
        """Return ``(thread_id -> ctx id map, per-id context list)``.

        This module never defines ``_THREAD_ID_TO_CONTEXT`` or
        ``_CURRENT_CONTEXTS`` itself, so referring to them as globals of
        this module raises NameError. The registries used instead are the
        ones owned by ``vllm.v1.worker.ubatching``, the module this file
        derives from.
        NOTE(review): relies on upstream keeping these two private names;
        confirm against the pinned vLLM version.
        """
        # Imported lazily so that the lookup always sees the live upstream
        # module state.
        from vllm.v1.worker import ubatching as upstream_ubatching
        return (upstream_ubatching._THREAD_ID_TO_CONTEXT,
                upstream_ubatching._CURRENT_CONTEXTS)

    def __enter__(self):
        thread_to_context, current_contexts = self._registries()
        thread_to_context[threading.get_ident()] = self.id
        current_contexts[self.id] = self
        self.ready_barrier.wait()

        # Block until this micro-batch is given control.
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()
        # Assume we want to start on the compute stream
        self.update_stream(self.compute_stream)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        thread_to_context, current_contexts = self._registries()
        current_contexts[self.id] = None
        del thread_to_context[threading.get_ident()]
        self.maybe_run_recv_hook()
        # Hand control to the other micro-batch before leaving.
        self.cpu_signal_event.set()
        self.cpu_wait_event.clear()
        # Never suppress an exception raised inside the with-block.
        return False

    def _restore_context(self):
        """Make this micro-batch's forward context the active one."""
        forward_context._forward_context = self.forward_context

    def update_stream(self, stream):
        """Record ``stream`` as current and switch to it if needed."""
        self.current_stream = stream
        if current_stream() != self.current_stream:
            torch.supa.set_stream(self.current_stream)

    def _signal_comm_done(self):
        self.gpu_comm_done_event.record(self.comm_stream)

    def _signal_compute_done(self):
        self.gpu_compute_done_event.record(self.compute_stream)

    def _wait_compute_done(self):
        self.comm_stream.wait_event(self.gpu_compute_done_event)

    def _wait_comm_done(self):
        self.compute_stream.wait_event(self.gpu_comm_done_event)

    def _cpu_yield(self):
        """Wake the other micro-batch and sleep until woken again."""
        # It is critical for correctness that only one thread is running
        # at a time. These asserts just make sure that this is the only
        # thread running before waking the other one up and going to sleep
        assert forward_context._forward_context == self.forward_context
        assert current_stream() == self.current_stream
        assert not self.cpu_wait_event.is_set()

        self.cpu_signal_event.set()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm(self):
        """Switch to the communication stream without synchronizing."""
        self.update_stream(self.comm_stream)

    def switch_to_compute(self):
        """Switch to the compute stream without synchronizing."""
        self.update_stream(self.compute_stream)

    def switch_to_comm_sync(self):
        """Switch to the comm stream, ordered after pending compute work."""
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def switch_to_compute_sync(self):
        """Switch to the compute stream, ordered after pending comm work."""
        self._signal_comm_done()
        self.update_stream(self.compute_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        """Run and clear the pending receive hook, if one is registered."""
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

    def yield_(self):
        """Yield the CPU to the other micro-batch, keeping the same stream."""
        # Remember whichever stream is active so it can be re-selected once
        # this micro-batch is resumed.
        self.current_stream = current_stream()
        self._cpu_yield()
        self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        """Yield, then resume on the comm stream after compute has finished."""
        assert current_stream() == self.compute_stream
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        """Yield, then resume on the compute stream after comm has finished."""
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()
def supa_make_ubatch_contexts(
    num_micro_batches: int,
    compute_stream: torch.supa.Stream,
    comm_stream: torch.supa.Stream,
    forward_contexts: list[ForwardContext],
    ready_barrier: threading.Barrier,
    schedule: str = "default",
) -> list[UBatchContext]:
    """
    Create a context manager for micro-batching synchronization.

    Args:
        num_micro_batches: Number of micro-batches; must be 2.
        compute_stream: Stream shared by all contexts for compute work.
        comm_stream: Stream shared by all contexts for communication work.
        forward_contexts: One forward context per micro-batch.
        ready_barrier: Barrier shared by all contexts.
        schedule: Schedule name forwarded to every context.

    Returns:
        One context per micro-batch, in micro-batch order. Each context
        waits on its own CPU event and signals the next context's event
        (wrapping around), so the contexts take turns.
    """
    assert num_micro_batches == 2, "only been tested with 2 micro-batches"

    cpu_events = [threading.Event() for _ in range(num_micro_batches)]
    gpu_comm_done_events = [
        torch.supa.Event() for _ in range(num_micro_batches)
    ]
    gpu_compute_done_events = [
        torch.supa.Event() for _ in range(num_micro_batches)
    ]

    assert len(forward_contexts) == 2

    ctxs = []
    for i in range(num_micro_batches):
        # `UBatchContext` is looked up at call time, i.e. after this module
        # has rebound the name to the SUPA implementation.
        ctx = UBatchContext(id=i,
                            compute_stream=compute_stream,
                            comm_stream=comm_stream,
                            forward_context=forward_contexts[i],
                            ready_barrier=ready_barrier,
                            cpu_wait_event=cpu_events[i],
                            cpu_signal_event=cpu_events[(i + 1) %
                                                        num_micro_batches],
                            gpu_comm_done_event=gpu_comm_done_events[i],
                            gpu_compute_done_event=gpu_compute_done_events[i],
                            schedule=schedule)
        ctxs.append(ctx)

    return ctxs
# Rebind the names imported from vllm.v1.worker.ubatching so that, within this
# module (and for anyone importing these names from it), they refer to the
# SUPA implementations defined above.
# NOTE(review): these assignments only change this module's namespace; they do
# not replace the attributes on vllm.v1.worker.ubatching itself. If upstream is
# meant to be patched, confirm that happens elsewhere.
UBatchContext = SUPAUBatchContext
make_ubatch_contexts = supa_make_ubatch_contexts

View File

@@ -0,0 +1,86 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from typing import TYPE_CHECKING, Optional
import torch
from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.attention.layer import Attention
def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: Optional[int] = 1,
) -> None:
    """
    Attach the allocated KV cache tensors to the runner and to each layer.

    Two things happen:
      1) `runner_kv_caches` (which must be empty on entry) is filled with
         one tensor per layer index, in ascending layer-index order.
      2) Every attention layer in `forward_context` gets its tensor from
         `kv_caches` stored in its `kv_cache` attribute.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache list declared by ModelRunner; mutated
            in place.
        num_attn_module: Forwarded unchanged to `extract_layer_index` when
            deriving a layer index from a layer name.

    Raises:
        NotImplementedError: If several layers share one layer index on a
            platform other than CUDA, XPU or SUPA.
    """
    assert len(runner_kv_caches) == 0

    # Group the layer names by the layer index encoded in each name.
    names_by_index = defaultdict(list)
    for name in kv_caches:
        index = extract_layer_index(name, num_attn_module)
        names_by_index[index].append(name)

    for index in sorted(names_by_index):
        names = names_by_index[index]
        # Several names per index happens e.g. for encoder-decoder models
        # such as bart, where cross attention and self attention of one
        # decoder layer have different names but the same index.
        # TODO - analyze where runner_kv_caches is used and the right
        # way to ensure it properly reflects multiple attention layers
        # in the same decoder block.
        # On CUDA / XPU / SUPA the runner is known not to be impacted by
        # keeping only the first tensor, so only other platforms reject it.
        if len(names) > 1 and not (current_platform.is_cuda()
                                   or current_platform.is_xpu()
                                   or current_platform.is_supa()):
            raise NotImplementedError
        runner_kv_caches.append(kv_caches[names[0]])

    # NOTE: Use list because of v0 PP virtual engine.
    for name, cache in kv_caches.items():
        forward_context[name].kv_cache = [cache]

429
vllm_br/v1/worker/worker.py Normal file
View File

@@ -0,0 +1,429 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
"""A GPU worker class."""
import copy
import datetime
import gc
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch.nn as nn
import vllm.envs as envs
import vllm_br.envs as br_envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.worker_base import WorkerBase
from vllm_br.platform import SUPAPlatform
from vllm_br.utils import GiB_bytes, SUPAMemorySnapshot
from vllm_br.v1.worker.model_runner import SUPAModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
class SUPAWorker(WorkerBase):
    """Worker that owns one SUPA device and drives a `SUPAModelRunner` on it."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Store configuration and optionally prepare the torch profiler.

        All arguments are forwarded unchanged to `WorkerBase.__init__`.
        """
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )
        self.kv_transfer_config = vllm_config.kv_transfer_config

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Buffers saved before sleep
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        trace_dir = envs.VLLM_TORCH_PROFILER_DIR
        if not trace_dir:
            self.profiler = None
        else:
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                trace_dir,
            )
            self.profiler = torch.profiler.profile(
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    trace_dir, use_gzip=True),
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.SUPA,  # type: ignore
                ],
                schedule=torch.profiler.schedule(wait=0,
                                                 warmup=0,
                                                 active=1,
                                                 repeat=1),
                profile_memory=False,
                record_shapes=True,
                with_stack=False,
                use_supa_simple=True,  # type: ignore
            )

    def sleep(self, level: int = 1) -> None:
        """Sleep mode is not implemented for SUPA."""
        raise NotImplementedError

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Sleep mode is not implemented for SUPA."""
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the decided block counts on the cache config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Select the device, set up distributed state, build the runner.

        Raises:
            RuntimeError: If the configured device type is not "supa".
        """
        if self.device_config.device.type != "supa":
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        self.device = torch.device(f"supa:{self.local_rank}")
        if self.kv_transfer_config is not None:
            # With KV transfer configured, the device index is shifted by
            # the "device_cursor" value from the extra config (default 0).
            device_cursor = self.kv_transfer_config.get_from_extra_config(
                "device_cursor", 0)
            self.device = torch.device(
                f"supa:{self.local_rank + int(device_cursor)}")
        SUPAPlatform.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)

        # Initialize the distributed environment BEFORE taking
        # memory snapshot
        # This ensures SUCCL buffers are allocated before we measure
        # available memory
        self._init_worker_distributed_environment()

        # Set random seed.
        set_random_seed(self.model_config.seed)

        gc.collect()
        torch.supa.empty_cache()
        # Total device memory; used as the reference value in
        # determine_available_memory().
        self.init_gpu_memory = SUPAPlatform.get_device_total_memory()

        # Construct the model runner
        self.model_runner: SUPAModelRunner = SUPAModelRunner(  # type: ignore
            self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        """Load the model through the model runner.

        Raises:
            NotImplementedError: If sleep mode is enabled.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            raise NotImplementedError('SUPA do not support sleep mode')
        self.model_runner.load_model()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profile a dummy forward pass and return the KV cache budget.

        The budget, in bytes, is
        ``total_gpu_memory * gpu_memory_utilization`` minus the memory
        torch reports as allocated right after the profiling run.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        torch.supa.empty_cache()
        _, total_gpu_memory = torch.supa.mem_get_info()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        before_profile = SUPAMemorySnapshot()
        after_profile = SUPAMemorySnapshot()
        before_profile.measure()
        self.model_runner.profile_run()
        after_profile.measure()

        free_gpu_memory, _ = torch.supa.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        # NOTE(review): init_gpu_memory holds the device *total* memory (see
        # init_device), although the message calls it "free" — confirm intent.
        assert self.init_gpu_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        peak_memory = torch.supa.memory_allocated()

        # Check for any memory left around that may have been allocated on the
        # gpu outside of `torch`. NCCL operations, for example, can use a few
        # GB during a forward pass
        torch.supa.empty_cache()
        torch_allocated_bytes = SUPAPlatform.get_memory_stats(
            self.device, "allocated_bytes.all.current")
        total_allocated_bytes = (torch.supa.mem_get_info()[1] -
                                 torch.supa.mem_get_info()[0])
        # Only reported in the log below; it is not added to peak_memory.
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes

        utilization = self.cache_config.gpu_memory_utilization
        memory_for_current_instance = total_gpu_memory * utilization
        available_kv_cache_memory = memory_for_current_instance - peak_memory

        diff_profile = after_profile - before_profile

        def _gib(num_bytes) -> str:
            # Format a byte count the way the log message expects it.
            return f"{(num_bytes / GiB_bytes):.2f}GiB"

        msg = (f"Memory profiling takes {diff_profile.timestamp:.2f} seconds\n"
               "the current vLLM instance can use "
               "total_gpu_memory "
               f"({_gib(total_gpu_memory)})"
               " x gpu_memory_utilization "
               f"({utilization:.2f})"
               f" = {_gib(memory_for_current_instance)}\n"
               "model weights take "
               f"{_gib(self.model_runner.model_memory_usage)};"
               " non_torch_memory takes "
               f"{_gib(non_torch_allocations)};"
               " PyTorch activation peak memory takes "
               f"{_gib(diff_profile.torch_peak)};"
               " the rest of the memory reserved for KV Cache is "
               f"{_gib(available_kv_cache_memory)}.")
        logger.info(msg)

        return int(available_kv_cache_memory)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV cache spec reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config.

        Raises:
            NotImplementedError: If sleep mode is enabled.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            raise NotImplementedError('SUPA do not support sleep mode')
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
        """Warm up compile sizes, capture graphs, and warm up the sampler."""
        capture_graphs = not self.model_config.enforce_eager

        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if capture_graphs:
            warmup_sizes = [
                x for x in warmup_sizes
                if x not in self.scheduler_config.cuda_graph_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size,
                                         skip_eplb=True,
                                         remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        if capture_graphs:
            self.model_runner.capture_model()

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `SUPAPlatform.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
            )
            hidden_states, last_hidden_states = \
                self.model_runner._dummy_run(
                    num_tokens=max_num_reqs,
                    skip_eplb=True,
                )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(
                    hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the loaded model."""
        return self.model_runner.get_supported_tasks()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
        """Run one step, exchanging tensors with neighbouring PP ranks.

        Returns:
            The runner output when the runner produces one. When the runner
            returns intermediate tensors instead, they are sent to the next
            pipeline stage and the result is ``None``, the shared empty
            output, or a copy of it carrying the KV connector output.
        """
        intermediate_tensors = None
        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
        if forward_pass and not get_pp_group().is_first_rank:
            # NOTE: the all-gather variant of the receive
            # (all_gather_group=get_tp_group()) is disabled here.
            if br_envs.VLLM_PP_CPU_SEND_RECV:
                # use cpu send/recv: tensors arrive on the host and are
                # moved onto the current device afterwards.
                host_tensors = get_pp_group().recv_tensor_dict()
                device_tensors = {
                    k: v.to(torch.supa.current_device())
                    for k, v in host_tensors.items()
                }
                intermediate_tensors = IntermediateTensors(device_tensors)
            else:
                intermediate_tensors = IntermediateTensors(
                    get_pp_group().recv_tensor_dict())

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert parallel_config.distributed_executor_backend != (
            "external_launcher") and not get_pp_group().is_last_rank

        if br_envs.VLLM_PP_CPU_SEND_RECV:
            # use cpu send/recv: copy to host before sending.
            host_tensors = {k: v.cpu() for k, v in output.tensors.items()}
            get_pp_group().send_tensor_dict(host_tensors)
        else:
            get_pp_group().send_tensor_dict(output.tensors)

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if (not kv_connector_output.finished_sending
                and not kv_connector_output.finished_recving):
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Return the draft token ids held by the model runner, if any."""
        return self.model_runner.take_draft_token_ids()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the torch profiler.

        Raises:
            RuntimeError: If profiling was not enabled at construction.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def execute_dummy_batch(self) -> None:
        """Run a dummy forward pass with a single token."""
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the runner's model as sharded state under ``path``.

        ``pattern`` and ``max_size`` are forwarded unchanged to
        ``ShardedStateLoader.save_model``.
        """
        # Prefer the package-level export and fall back to the legacy
        # `.loader` submodule, so that the import works whichever of the two
        # locations the installed vLLM provides.
        # NOTE(review): confirm against the pinned vLLM version.
        try:
            from vllm.model_executor.model_loader import ShardedStateLoader
        except ImportError:
            from vllm.model_executor.model_loader.loader import (
                ShardedStateLoader)
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def _init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        set_custom_all_reduce(
            not self.parallel_config.disable_custom_all_reduce)

        init_distributed_environment(self.parallel_config.world_size,
                                     self.rank,
                                     self.distributed_init_method,
                                     self.local_rank,
                                     "sccl",
                                     timeout=datetime.timedelta(seconds=100))

        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

        ensure_kv_transfer_initialized(self.vllm_config)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
# TODO: add checkers
return
if torch_dtype == torch.bfloat16: # noqa: SIM102
capability = SUPAPlatform.get_device_capability()
gpu_name = SUPAPlatform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.")