first commit
This commit is contained in:
19
vllm_br/v1/__init__.py
Normal file
19
vllm_br/v1/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import attention # noqa: F401
|
||||
from . import executor # noqa: F401
|
||||
from . import core, engine, kv_cache_interface, outputs, sample # noqa: F401
|
||||
BIN
vllm_br/v1/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/__pycache__/kv_cache_interface.cpython-310.pyc
Normal file
BIN
vllm_br/v1/__pycache__/kv_cache_interface.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/__pycache__/outputs.cpython-310.pyc
Normal file
BIN
vllm_br/v1/__pycache__/outputs.cpython-310.pyc
Normal file
Binary file not shown.
17
vllm_br/v1/attention/__init__.py
Normal file
17
vllm_br/v1/attention/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import backends # noqa: F401
|
||||
BIN
vllm_br/v1/attention/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/attention/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
17
vllm_br/v1/attention/backends/__init__.py
Normal file
17
vllm_br/v1/attention/backends/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from . import mla # noqa: F401
|
||||
from .utils import *
|
||||
Binary file not shown.
Binary file not shown.
BIN
vllm_br/v1/attention/backends/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm_br/v1/attention/backends/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
657
vllm_br/v1/attention/backends/attention_v1.py
Normal file
657
vllm_br/v1/attention/backends/attention_v1.py
Normal file
@@ -0,0 +1,657 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
|
||||
get_flash_attn_version)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.attention.backends.flash_attn import _get_sliding_window_configs
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
get_kv_cache_layout,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm_br.config.compilation import SUPAGraphMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
# from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
class SUPAFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the SUPA flash-attention implementation.

    Exposes the implementation / metadata / metadata-builder classes of this
    backend together with helpers describing the KV-cache shape and layout.
    """

    # NOTE: When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True
    accept_output_buffer: bool = False
    supports_quant_query_input: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this backend."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes accepted by this backend."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` if ``head_size`` is not a supported size."""
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "SUPAFLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["SUPAFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return SUPAFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["SUPAFlashAttentionMetadata"]:
        """Return the attention metadata class."""
        return SUPAFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["SUPAFlashAttentionMetadataBuilder"]:
        """Return the attention metadata builder class."""
        return SUPAFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        Returns:
            ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the 4-D "usharp" KV-cache shape used on SUPA devices.

        ``th_gran`` blocks (see ``get_kv_cache_usharp_alignment``) are folded
        into one row group and the head dimensions are flattened, giving
        ``(2, n_block, th_gran * block_size, num_kv_heads * head_size)`` where
        ``n_block`` is ``ceil(num_blocks / th_gran)``, but at least 1.

        Raises:
            ValueError: if no positive alignment exists for ``block_size``
                (previously this surfaced as a ``ZeroDivisionError``).
        """
        th_gran = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
            block_size)
        if th_gran < 1:
            raise ValueError(
                f"Block size {block_size} yields no valid usharp alignment "
                f"(got {th_gran}).")
        # Ceiling division; always keep at least one block group.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        usharp_shape = (2, n_block, th_gran * block_size,
                        num_kv_heads * head_size)
        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        logger.debug(
            "Origin kv cache shape is [2, %s, %s, %s, %s]. "
            "For SUPA Speed up, use [2, %s, %s, %s]", num_blocks, block_size,
            num_kv_heads, head_size, usharp_shape[1], usharp_shape[2],
            usharp_shape[3])
        return usharp_shape

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks of ``block_size`` fit in the 2048 limit."""
        max_h_limit = 2048
        return max_h_limit // block_size

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the permutation from the logical shape to memory layout.

        Raises:
            ValueError: if the configured cache layout is neither ``"NHD"``
                nor ``"HND"``.
        """
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 KV-cache dtype string to the torch dtype.

        Raises:
            ValueError: if ``kv_cache_dtype`` is not ``"fp8"``/``"fp8_e4m3"``.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||
|
||||
|
||||
@dataclass
class SUPAFlashAttentionMetadata:
    # Per-batch metadata handed to SUPAFlashAttentionImpl.forward().
    #
    # Terminology for context_len, query_len and seq_len:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # --- Generic batch description -------------------------------------
    # Token count with padding excluded.
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # --- Parameters specific to the BIREN attention kernels ------------
    seq_start_loc: torch.Tensor
    context_lens: torch.Tensor
    max_decode_seq_len: int
    # Set to False when attention-split is used (no KV-cache store).
    do_cache: bool
    num_actual_reqs: torch.Tensor

    # --- Graph mode ------------------------------------------------------
    supagraph_runtime_mode: SUPAGraphMode

    # --- Prefill / decode split ----------------------------------------
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # --- Cascade attention ----------------------------------------------
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # --- Optional ahead-of-time (AOT) scheduling -----------------------
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0

    causal: bool = True

    # NOTE: a nested LocalAttentionMetadata (for local attention) and the
    # matching `local_attn_metadata` field are intentionally not defined;
    # they were disabled in this backend.
|
||||
|
||||
|
||||
class SUPAFlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[SUPAFlashAttentionMetadata]):
    """Builds ``SUPAFlashAttentionMetadata`` from common attention metadata."""

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.ALWAYS

    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache the configuration values needed by ``build()``.

        Args:
            kv_cache_spec: KV-cache spec; its dtype and block size are stored.
            layer_names: Forwarded to the base builder.
            vllm_config: Source of the model / parallel / cache / compilation
                configs.
            device: Forwarded to the base builder.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        supports_spec_as_decode = True
        self._init_reorder_batch_threshold(1, supports_spec_as_decode)

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling is kept off for this backend; the schedule() helper
        # in build() raises NotImplementedError if it is ever requested.
        self.aot_schedule = False

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        self.max_cudagraph_size = self.compilation_config.max_capture_size

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> SUPAFlashAttentionMetadata:
        """Assemble the attention metadata for one scheduling step.

        Args:
            common_prefix_len: Length of the prefix shared by all requests;
                a value greater than 0 enables the cascade fields.
            common_attn_metadata: Batch description. Besides the generic
                fields this method also reads ``num_actual_reqs``,
                ``seq_start_loc``, ``context_lens``, ``max_decode_seq_len``
                and ``supagraph_runtime_mode`` from it (presumably a
                SUPA-extended metadata object -- confirm against the caller).
            fast_build: Disables AOT scheduling, used when there will be few
                iterations, i.e. spec-decode.

        Returns:
            The populated ``SUPAFlashAttentionMetadata``.

        Raises:
            NotImplementedError: if AOT scheduling is requested.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True)
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal
        num_actual_reqs = common_attn_metadata.num_actual_reqs
        seq_start_loc = common_attn_metadata.seq_start_loc
        context_lens = common_attn_metadata.context_lens

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler the sliding window value has to be the
            # same for all layers. It is populated on the first build() call
            # because the layers are not yet constructed in __init__.
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    # Mixed window sizes: AOT scheduling cannot be used.
                    self.aot_schedule = False
                    aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if self.use_full_cuda_graph and \
                num_actual_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # Honour the per-call decision (fast_build / mixed sliding
            # windows) rather than only the instance-level flag.
            if aot_schedule:
                raise NotImplementedError(
                    'aot schedule not support in SUPA attention')
            return None

        # NOTE: local (chunked) attention metadata is not built here.

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # Use the builder's own device; this class has no `runner`.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)

        if common_attn_metadata.max_decode_seq_len is None:
            # Fall back to the longest sequence in the batch.
            max_decode_seq_len = int(seq_lens.max().item())
        else:
            max_decode_seq_len = common_attn_metadata.max_decode_seq_len

        attn_metadata = SUPAFlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            do_cache=True,
            num_actual_reqs=num_actual_reqs,
            supagraph_runtime_mode=common_attn_metadata.supagraph_runtime_mode)

        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Always return ``False``: cascade attention is never selected."""
        return False
|
||||
|
||||
|
||||
class SUPAFlashAttentionImpl(AttentionImpl):
    """Attention implementation that dispatches to the ``torch_br`` kernels."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the layer hyper-parameters and validate the configuration.

        Raises:
            ValueError: for an unsupported head size or malformed ``sinks``.
            NotImplementedError: for unsupported attention types, or for a
                quantized KV cache when fp8 is not supported.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        # ALiBi slopes, when given, are materialised as a float32 CPU tensor.
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32, device="cpu"))
        # A falsy window (None or 0) is normalised to None.
        self.sliding_window = sliding_window or None
        self.kv_cache_dtype = kv_cache_dtype
        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        SUPAFlashAttentionBackend.validate_head_size(head_size)

        self.attn_type = attn_type

        supported_attn_types = (AttentionType.DECODER,
                                AttentionType.ENCODER_ONLY)
        if attn_type not in supported_attn_types:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")
        self.vllm_flash_attn_version = get_flash_attn_version()
        quantized_cache = is_quantized_kv_cache(self.kv_cache_dtype)
        if quantized_cache and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

        self.sinks: Optional[torch.Tensor] = None
        if sinks is None:
            return
        if sinks.shape[0] != num_heads:
            raise ValueError(
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Expected {num_heads}, but got "
                f"{sinks.shape[0]}.")
        if sinks.dtype != torch.float32:
            raise ValueError("Sinks must be of type float32, but got "
                             f"{sinks.dtype}.")
        self.sinks = sinks

    def _prefill_with_kvcache(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
    ) -> torch.Tensor:
        """Run ``br_flash_attn_with_kvcache_infer`` for the current batch."""
        return torch_br.br_flash_attn_with_kvcache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            self.head_size,
            alibi_slopes=self.alibi_slopes,
            softmax_scale=self.scale,
            num_reqs=attn_metadata.num_actual_reqs)

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention for one layer.

        Shapes below are as stated by the upstream docstring.

        Args:
            layer: The attention layer (not read in this method).
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention; ``None`` means a
                profiling run and ``query`` is returned unchanged.
            output: Must be ``None``; output buffers are not accepted.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is None, "Output tensor should not provided."
        if attn_metadata is None:
            # Profiling run.
            # FIXME: this may lead to a wrong block estimation.
            return query

        is_encoder = self.attn_type in (AttentionType.ENCODER_ONLY,
                                        AttentionType.ENCODER)

        # NOTE: supa attn use [batch_size, num_tokens, num_heads * head_size]
        # as shape
        should_store_kv = (kv_cache is not None and attn_metadata.do_cache
                           and not is_encoder)
        if should_store_kv:
            torch_br.supa_kvcache_store_infer_v2(
                kv_cache,
                key,
                value,  # type: ignore
                attn_metadata.slot_mapping,
                self.head_size)

        # Layers configured with sinks take the sliding-window path.
        if self.sinks is not None:
            return self.forward_sw_sinks(query, kv_cache, attn_metadata)

        if is_encoder:
            assert query.dim() == 3
            return torch_br.supa_flash_attention_infer(  # type: ignore
                query,
                key,
                value,
                attn_metadata.query_start_loc,
                self.head_size,
                len(attn_metadata.query_start_loc),  # type: ignore
                self.alibi_slopes,
                softmax_scale=self.scale,
                is_causal=_get_causal_option(self.attn_type))

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        graph_mode = attn_metadata.supagraph_runtime_mode
        decode_kernel_allowed = graph_mode is None or graph_mode in (
            SUPAGraphMode.NONE, SUPAGraphMode.FULL_DECODE_ONLY)

        # The dedicated decode kernel is only used when its graph mode
        # allows it and the batch has no prefill tokens (decode only).
        if decode_kernel_allowed and num_prefill_tokens <= 0:
            return torch_br.supa_attention_decoder_infer_v2(  # type: ignore
                query,  # type: ignore
                kv_cache,
                attn_metadata.block_table,
                attn_metadata.seq_lens,
                attn_metadata.max_decode_seq_len,
                self.head_size,
                attn_metadata.num_prefills,
                self.alibi_slopes,
                softmax_scale=self.scale)

        # prefill + decode (non-mtp), or any other graph mode
        return self._prefill_with_kvcache(query, kv_cache, attn_metadata)

    # sliding window with sinks impl
    def forward_sw_sinks(
        self,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SUPAFlashAttentionMetadata,
    ) -> torch.Tensor:
        """Run the prefix-enabled sliding-window attention with sinks."""
        return torch_br.supa_flash_attn_cache_infer(  # type: ignore
            query,
            kv_cache,
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            attn_metadata.context_lens,
            attn_metadata.slot_mapping,
            attn_metadata.max_seq_len,
            self.head_size,
            window_size=self.sliding_window,
            sinks=self.sinks)
|
||||
|
||||
|
||||
def _get_causal_option(attn_type: str) -> bool:
    """
    Tell whether attention of the given type should be computed causally.

    Args:
        attn_type (AttentionType): The type of attention being evaluated

    Returns:
        bool: `False` for encoder, encoder-only and encoder-decoder
        attention; `True` for every other attention type.
    """
    non_causal_types = (
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type not in non_causal_types
|
||||
19
vllm_br/v1/attention/backends/mla/__init__.py
Normal file
19
vllm_br/v1/attention/backends/mla/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import flashmla # noqa: F401
|
||||
from . import flashmla_sparse # noqa: F401
|
||||
from . import indexer # noqa: F401
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
657
vllm_br/v1/attention/backends/mla/flashmla.py
Normal file
657
vllm_br/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,657 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.v1.attention.backends.flash_attn import _get_sliding_window_configs
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonImpl,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.attention.backends.mla.flashmla import (FlashMLABackend,
|
||||
FlashMLAMetadata)
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm_br import envs
|
||||
from vllm_br.model_executor.layers.br_utils import _convert_to_numa_tensor
|
||||
from vllm_br.utils import get_grandparent_pid
|
||||
from vllm_br.v1.attention.backends.utils import SUPACommonAttentionMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SupaFlashMLABackend(FlashMLABackend):
    """FlashMLA attention backend entry point for Biren SUPA devices.

    Names the SUPA-specific metadata, metadata-builder and implementation
    classes and describes the KV-cache layouts used with them.
    """

    # NOTE: When piecewise cudagraph is enabled, this flag controls whether
    # the output tensor is allocated inside the cudagraph.
    # NOTE: currently, we do not support accept_output_buffer=True.
    accept_output_buffer: bool = False

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the head sizes accepted by this backend."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "SUPAFLASHMLA"

    @staticmethod
    def get_metadata_cls() -> type["SupaFlashMLAMetadata"]:
        """Return the attention metadata class used by this backend."""
        return SupaFlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["SupaFlashMLAMetadataBuilder"]:
        """Return the metadata builder class used by this backend."""
        return SupaFlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["SupaFlashMLAImpl"]:
        """Return the attention implementation class used by this backend."""
        return SupaFlashMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the generic 5-D KV-cache shape.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the merged 4-D KV-cache shape used on SUPA.

        Blocks are grouped ``th_gran`` at a time (see
        ``get_kv_cache_usharp_alignment``) and the head and head-size
        dimensions are flattened into one.

        Raises:
            ValueError: if ``block_size`` is so large that the alignment
                granularity becomes zero.
        """
        th_gran = SupaFlashMLABackend.get_kv_cache_usharp_alignment(block_size)
        if th_gran <= 0:
            # Without this guard the ceil-division below fails with an
            # opaque ZeroDivisionError.
            raise ValueError(
                f"block_size {block_size} is too large for the SUPA KV cache "
                "layout (alignment granularity would be 0).")
        # Ceil-divide the block count by the granularity, keeping >= 1 block.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        # Lazy %-style arguments: the message is only rendered when DEBUG
        # logging is enabled.
        logger.debug(
            "Origin kv cache shape is [1, %s, %s, %s, %s]. "
            "For SUPA Speed up, use [1, %s, %s, %s]",
            num_blocks, block_size, num_kv_heads, head_size,
            n_block, th_gran * block_size, num_kv_heads * head_size)
        # TODO, shared kv only used in deepseek
        return (1, n_block, th_gran * block_size, num_kv_heads * head_size)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks are merged into one row group.

        Computed as ``2048 // block_size``; the result is 0 when
        ``block_size`` exceeds 2048.
        """
        max_h_limit = 2048
        return max_h_limit // block_size
|
||||
|
||||
|
||||
@dataclass
class SupaFlashMLAMetadata:
    """Per-step attention metadata consumed by the SUPA FlashMLA backend.

    Relationship between context_len, query_len and seq_len:

        |---------- N-1 iteration --------|
        |---------------- N iteration ---------------------|
        |- tokenA -|......................|-- newTokens ---|
        |---------- context_len ----------|
        |-------------------- seq_len ---------------------|
                                          |-- query_len ---|
    """

    # --- Generic batch description -------------------------------------
    # Number of tokens excluding padding.
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # --- BIREN attention parameters ------------------------------------
    seq_start_loc: torch.Tensor
    context_lens: torch.Tensor
    max_decode_seq_len: int
    # When attention split is used, do_cache is False.
    do_cache: bool

    # --- Prefill / decode split ----------------------------------------
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int
    num_actual_reqs: torch.Tensor

    # --- Cascade attention ---------------------------------------------
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # --- Optional ahead-of-time (AOT) scheduling -----------------------
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class SupaFlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds ``SupaFlashMLAMetadata`` for each scheduling step on SUPA."""

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Initialise the builder and its optional full-cudagraph buffers."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashMLAMetadata)

        self.vllm_config = vllm_config
        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config)

        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

        device_properties = torch.cuda.get_device_properties(self.device)
        num_sms = device_properties.multi_processor_count

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.cg_buf_tile_scheduler_metadata = torch.zeros(
                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
                # TileSchedulerMetaDataSize = 8
                (num_sms, 8),
                device=self.device,
                dtype=torch.int32,
            )
            self.cg_buf_num_splits = torch.empty(
                (vllm_config.scheduler_config.max_num_seqs + 1),
                device=self.device,
                dtype=torch.int32)

        self.aot_schedule = False
        logger.warning(
            "AOT Schedule is disabled when using SUPAFlashAttention.")

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

        supports_spec_as_decode = True
        self._init_reorder_batch_threshold(1, supports_spec_as_decode)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: SUPACommonAttentionMetadata,
              fast_build: bool = False) -> "SupaFlashMLAMetadata":
        """Assemble the attention metadata for the current batch.

        Args:
            common_prefix_len: length of the prefix shared by all requests;
                a value > 0 selects the cascade-attention fields.
            common_attn_metadata: batch description provided by the runner.
            fast_build: when True, AOT scheduling is skipped for this call.

        Returns:
            A populated ``SupaFlashMLAMetadata``.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        max_seq_len = int(seq_lens_cpu[:num_reqs].max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        num_actual_reqs = common_attn_metadata.num_actual_reqs

        aot_schedule = self.aot_schedule and not fast_build

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True)

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call so the layers are constructed (cannot populate
            # in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
                    aot_schedule = False

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # AOT scheduling is not implemented on SUPA; the only supported
            # outcome is "no scheduler metadata".
            if self.aot_schedule:
                raise NotImplementedError(
                    'aot schedule not support in SUPA attention')
            return None

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # Use the builder's own device and the CPU sequence lengths from
            # the common metadata; nothing visible in this class sets a
            # ``runner`` attribute.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] -
                              common_prefix_len).to(self.device)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=True)

        seq_start_loc = common_attn_metadata.seq_start_loc
        if seq_start_loc is None:
            # One host copy, then a prefix sum over plain Python ints. This
            # avoids materialising (and possibly synchronising on) one tensor
            # element per request.
            host_lens = seq_lens.cpu().tolist()
            seq_start_loc = torch.tensor(
                [0] + list(itertools.accumulate(host_lens)),
                device=query_start_loc.device,
                dtype=torch.int32)

        context_lens = common_attn_metadata.context_lens
        if context_lens is None:
            # context = total sequence length minus tokens scheduled now.
            context_lens = seq_lens - (query_start_loc[1:] -
                                       query_start_loc[:-1])

        max_decode_seq_len = common_attn_metadata.max_decode_seq_len
        if max_decode_seq_len is None:
            max_decode_seq_len = int(seq_lens.max().item())

        attn_metadata = SupaFlashMLAMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            # Biren Attention Params
            seq_start_loc=seq_start_loc,
            context_lens=context_lens,
            max_decode_seq_len=max_decode_seq_len,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            do_cache=True,
            num_actual_reqs=num_actual_reqs)

        return attn_metadata

    def can_run_in_cudagraph(
            self, common_attn_metadata: SUPACommonAttentionMetadata) -> bool:
        """Report whether this batch may run inside a full CUDA graph.

        Always returns False for this builder.
        """
        return False

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Cascade attention is never requested by this builder."""
        return False
|
||||
|
||||
|
||||
# class SupaFlashMLAImpl(FlashMLAImpl):
|
||||
class SupaFlashMLAImpl(MLACommonImpl[SupaFlashMLAMetadata]):
    """MLA attention implementation built on the ``torch_br`` SUPA kernels.

    The projection weights handed to ``__init__`` are fused into three
    device tensors (``w_qkv_a``, ``w_qk_b``, ``w_vo``) by
    ``process_weights_after_loading``; ``forward`` then runs the fused
    prefix kernel, the attention kernel and the output projection.
    """

    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            q_lora_rank: Optional[int],
            kv_lora_rank: int,
            qk_nope_head_dim: int,
            qk_rope_head_dim: int,
            qk_head_dim: int,
            v_head_dim: int,
            kv_b_proj: ColumnParallelLinear,
            rotary_emb: RotaryEmbedding,
            # q_proj should be q_b_proj if q_lora_rank is not None, but from
            # an attention backend perspective we rely on the layer to pass
            # in the correct matrix
            q_proj: ColumnParallelLinear,  # q_b_proj
            o_proj: RowParallelLinear,
            kv_a_proj_with_mqa: ReplicatedLinear,
            kv_a_layernorm: Any,
            q_a_proj: ReplicatedLinear,
            q_a_layernorm: Any,
            # MLA Specific Arguments
            **mla_args) -> None:
        """Store the projection layers and validate the configuration.

        Raises:
            NotImplementedError: for alibi slopes, sliding window, logits
                soft cap, non-decoder attention or a quantized KV cache.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, q_lora_rank,
                         kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim,
                         qk_head_dim, v_head_dim, kv_b_proj, **mla_args)

        self.rotary_emb = rotary_emb

        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj
        self.kv_a_proj_with_mqa = kv_a_proj_with_mqa
        self.kv_a_layernorm = kv_a_layernorm
        self.q_a_layernorm = q_a_layernorm
        self.q_a_proj = q_a_proj

        self.tp_size = get_tensor_model_parallel_world_size()
        cur_device = torch.supa.current_device()
        self.spc_num = torch_br.supa.get_device_properties(
            cur_device).max_compute_units

        # NOTE(review): p2p is initialised only for tp_size == 8 with 16
        # compute units, while forward() enables the fused all-reduce for
        # more (spc_num, tp_size) pairs -- confirm this is intended.
        if (envs.VLLM_BR_USE_FUSED_ALLREDUCE and self.tp_size == 8
                and self.spc_num == 16):
            # Initialize the p2p info
            torch.supa.init_p2p_remote_id(cur_device)

        assert self.q_lora_rank is not None
        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "SUPAFlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "SUPAFlashMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "SUPAFlashMLA V1 with FP8 KV cache not yet supported")

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Fuse the loaded projection weights into SUPA-layout tensors.

        Produces ``self.w_qkv_a``, ``self.w_vo`` and ``self.w_qk_b``, drops
        the original layer weights and casts both layernorm weights to
        float32.

        Raises:
            NotImplementedError: if ``q_lora_rank`` is None.
        """

        def get_layer_weight(layer):
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(layer.input_size_per_partition,
                                dtype=act_dtype,
                                device=get_layer_weight(layer).device)
                dequant_weights = layer.quant_method.apply(layer,
                                                           eye,
                                                           bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        if self.q_lora_rank is None:
            raise NotImplementedError

        # handle deepseek_v3 weight
        w_q_a = get_and_maybe_dequant_weights(self.q_a_proj).T
        w_kv_a = get_and_maybe_dequant_weights(self.kv_a_proj_with_mqa).T
        w_qkv_a = torch.cat([w_q_a, w_kv_a], dim=-1)
        # w_qkv_a must make two copies in br166
        align_size = 32
        die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
        if die_spc_num > 16:
            w_qkv_a = torch.cat([w_qkv_a, w_qkv_a], dim=-1)
        self.w_qkv_a = _convert_to_numa_tensor(w_qkv_a, align_size,
                                               "colmajor", w_qkv_a.dtype)

        # Split kv_b into its per-head K (nope) and V parts.
        w_kv_b = get_and_maybe_dequant_weights(self.kv_b_proj).T
        w_k_b, w_v_b = w_kv_b.reshape(
            self.kv_lora_rank, -1,
            self.qk_nope_head_dim + self.v_head_dim).split(
                [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        w_k_b = w_k_b.permute(1, 2, 0).contiguous()
        w_v_b = w_v_b.permute(1, 0, 2).contiguous()

        # Fold the V up-projection into the output projection.
        w_o = get_and_maybe_dequant_weights(self.o_proj.to(w_v_b.device)).T
        hidden_dim = w_o.shape[-1]
        w_o = w_o.reshape(-1, self.v_head_dim, hidden_dim)
        w_vo = torch.bmm(w_v_b, w_o).reshape(-1, hidden_dim)
        self.w_vo = _convert_to_numa_tensor(w_vo,
                                            align_size,
                                            "colmajor",
                                            w_qkv_a.dtype,
                                            parallel_type="row_parallel")

        # replace q_b_proj as q_proj
        w_q_b = get_and_maybe_dequant_weights(self.q_proj).T
        w_q_b_nope, w_q_b_rope = w_q_b.reshape(
            self.q_lora_rank, -1, self.qk_head_dim).split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        w_q_b_nope = w_q_b_nope.permute(1, 0, 2).contiguous()
        w_q_b_rope = w_q_b_rope.reshape(self.q_lora_rank, -1)

        # Fold the K (nope) up-projection into the Q up-projection.
        w_qk_b_nope = torch.bmm(w_q_b_nope, w_k_b).permute(
            1, 0, 2).contiguous().reshape(self.q_lora_rank, -1)
        # w_qk_b_nope w_q_b_rope is independent head, separate like
        # QKVParallelLinear
        if die_spc_num > 16:
            qk_b_nope0, qk_b_nope1 = torch.chunk(w_qk_b_nope, 2, dim=-1)
            qk_b_rope0, qk_b_rope1 = torch.chunk(w_q_b_rope, 2, dim=-1)
            w_qk_b = torch.cat(
                [qk_b_nope0, qk_b_rope0, qk_b_nope1, qk_b_rope1], dim=-1)
        else:
            w_qk_b = torch.cat([w_qk_b_nope, w_q_b_rope], dim=-1)
        self.w_qk_b = _convert_to_numa_tensor(w_qk_b, align_size,
                                              "colmajor", w_qkv_a.dtype)

        # The fused tensors replace the per-layer weights.
        self.q_a_proj.weight = None
        self.kv_a_proj_with_mqa.weight = None
        self.q_proj.weight = None
        self.kv_b_proj.weight = None
        self.o_proj.weight = None

        if self.kv_a_layernorm.weight.dtype != torch.float32:
            self.kv_a_layernorm.weight.data = self.kv_a_layernorm.weight.to(
                torch.float32)
        if self.q_a_layernorm.weight.dtype != torch.float32:
            self.q_a_layernorm.weight.data = self.q_a_layernorm.weight.to(
                torch.float32)

        torch.supa.empty_cache()

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: SupaFlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Not used on SUPA; ``forward`` handles decode itself."""
        raise NotImplementedError

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states: torch.Tensor,  # query in unified attn
        positions: torch.Tensor,  # reuse k_c_normed as position
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: SupaFlashMLAMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run fused MLA attention and the output projection.

        Args:
            layer: attention layer object (not read by this method).
            hidden_states: layer input, passed in the "query" slot of the
                unified attention interface.
            positions: token positions, passed in the "k_c_normed" slot.
            k_pe: occupies the "value" slot; not read by this method.
            kv_cache: paged KV cache, read and written by the SUPA kernels.
            attn_metadata: metadata for this step, or None during
                profiling / warm-up.
            output: must be None; output buffers are not supported.

        Returns:
            The attention output after the output projection (all-reduced
            across tensor-parallel ranks when ``tp_size > 1``), or
            ``hidden_states`` unchanged when ``attn_metadata`` is None.
        """
        assert output is None, "Output tensor should not provided."
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
                self, "grandparent_pid"):
            self.grandparent_pid = get_grandparent_pid()

        # profile and warm up mla attention kernel
        if attn_metadata is None:
            return hidden_states

        # handle deepseek_v3 mla: the v2 kernel is used for up to 512 tokens
        # (dimension 1 of hidden_states), the v3 kernel beyond that. Both
        # take the same arguments.
        mla_prefix_infer = (torch_br.supa_mla_prefix_infer_v2
                            if hidden_states.shape[1] <= 512 else
                            torch_br.supa_mla_prefix_infer_v3)
        query, _ = mla_prefix_infer(
            hidden_states, self.w_qkv_a, self.w_qk_b,
            self.q_a_layernorm.weight, self.kv_a_layernorm.weight,
            self.rotary_emb.sin_cache, self.rotary_emb.cos_cache,
            positions, kv_cache, attn_metadata.slot_mapping,
            self.num_heads, self.qk_head_dim, self.qk_nope_head_dim,
            self.qk_rope_head_dim, self.kv_lora_rank, self.v_head_dim,
            self.q_lora_rank, self.kv_a_layernorm.variance_epsilon)

        # Both attention kernels return their own output tensor, so nothing
        # is pre-allocated here.
        if attn_metadata.num_prefill_tokens > 0:
            assert len(query.shape) == 3
            output = torch_br.br_flash_attn_with_kvcache_infer(  # type: ignore
                query,
                kv_cache,
                attn_metadata.query_start_loc,
                attn_metadata.seq_start_loc,
                attn_metadata.block_table,
                self.head_size,
                alibi_slopes=None,
                softmax_scale=self.scale,
                v_head_size=self.kv_lora_rank,
                num_reqs=attn_metadata.num_actual_reqs,
            )
        else:
            assert len(query.shape) == 3 and attn_metadata.num_prefills == 0
            output = torch_br.supa_attention_decoder_infer_v2(  # type: ignore
                query,  # type: ignore
                kv_cache,
                attn_metadata.block_table,
                attn_metadata.seq_lens,
                attn_metadata.max_decode_seq_len,
                self.head_size,
                attn_metadata.num_prefills,
                alibi_slopes=None,
                softmax_scale=self.scale,
                v_head_size=self.kv_lora_rank,
            )

        # The fused linear+allreduce kernel is only used for short inputs
        # and for the (spc_num, tp_size) combinations listed below.
        seq_len = hidden_states.shape[-2]
        tp_size = self.tp_size
        support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
        fused_comm = (envs.VLLM_BR_USE_FUSED_ALLREDUCE
                      and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
                      and
                      (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)

        if fused_comm:
            tp_rank = get_tp_group().rank_in_group
            global_rank = get_tp_group().rank
            rank_i = global_rank % tp_size
            assert rank_i == tp_rank
            o_proj_out = torch_br.supa_fused_linear_allreduce_opt(
                output, self.w_vo, hidden_states.shape[-1], tp_rank, tp_size,
                global_rank, 0)
        else:
            # do o_proj
            output_parallel = torch_br.br_fused_mlp_infer(
                output, [self.w_vo], output_w=hidden_states.shape[-1])
            if self.tp_size > 1:
                o_proj_out = tensor_model_parallel_all_reduce(output_parallel)
            else:
                o_proj_out = output_parallel

        return o_proj_out
|
||||
450
vllm_br/v1/attention/backends/mla/flashmla_sparse.py
Normal file
450
vllm_br/v1/attention/backends/mla/flashmla_sparse.py
Normal file
@@ -0,0 +1,450 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata
|
||||
from vllm.attention.ops.flashmla import get_mla_metadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend, FlashMLASparseImpl, FlashMLASparseMetadata,
|
||||
FlashMLASparseMetadataBuilder)
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_NO_DEFAULT = object()
|
||||
|
||||
|
||||
@dataclass
class SupaFlashMLASparseMetadata(FlashMLASparseMetadata):
    """Sparse FlashMLA metadata extended with BIREN attention parameters.

    Every field below carries a placeholder default so it can be declared
    after the parent's fields; ``__post_init__`` rejects instances where a
    placeholder is still present.
    """

    # BIREN Attention Params
    seq_start_loc: torch.Tensor = _NO_DEFAULT
    context_lens: torch.Tensor = _NO_DEFAULT
    max_decode_seq_len: int = -1
    num_prefills: int = -1
    num_decodes: int = -1
    num_prefill_tokens: int = -1
    num_decode_tokens: int = -1

    def __post_init__(self):
        """Raise ``TypeError`` if any SUPA field was left at its placeholder."""
        tensor_missing = (self.seq_start_loc is _NO_DEFAULT
                          or self.context_lens is _NO_DEFAULT)
        counters = (
            self.max_decode_seq_len,
            self.num_prefills,
            self.num_decodes,
            self.num_prefill_tokens,
            self.num_decode_tokens,
        )
        if tensor_missing or any(value == -1 for value in counters):
            raise TypeError("__init__ missing required argument")
|
||||
|
||||
|
||||
class SupaFlashMLASparseMetadataBuilder(FlashMLASparseMetadataBuilder):
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    """Initialise the parent builder, then widen the decode threshold.

    With speculative decoding configured, the reorder threshold grows by at
    most one extra token per request.
    """
    super().__init__(kv_cache_spec=kv_cache_spec,
                     layer_names=layer_names,
                     vllm_config=vllm_config,
                     device=device)
    self.vllm_config = vllm_config

    spec_config = self.vllm_config.speculative_config
    if spec_config:
        self.num_speculative_tokens = spec_config.num_speculative_tokens
    else:
        self.num_speculative_tokens = 0

    # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
    extra_decode_tokens = min(self.num_speculative_tokens, 1)
    self.reorder_batch_threshold = (self.reorder_batch_threshold +
                                    extra_decode_tokens)
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """Reorder the batch so prefills come first and decodes come last.

    A request counts as a "decode" when exactly one of its scheduled tokens
    is not a speculative token; every other request is a "prefill". The
    per-kind request and token counts are stored on ``self`` for the next
    ``build`` call.

    Returns:
        True if at least one pair of requests was swapped.
    """
    decode_slots = []
    prefill_slots = []
    decode_token_total = 0
    prefill_token_total = 0

    for slot, req_id in enumerate(input_batch.req_ids):
        scheduled = scheduler_output.num_scheduled_tokens[req_id]
        speculative = len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        if scheduled - speculative == 1:
            decode_slots.append(slot)
            decode_token_total += scheduled
        else:
            prefill_slots.append(slot)
            prefill_token_total += scheduled

    decode_count = len(decode_slots)
    prefill_count = len(prefill_slots)
    swapped = False

    # Both slot lists are in ascending order. Walk the prefills from the
    # back of the batch and the decodes from the front: a prefill sitting
    # at or beyond index ``prefill_count`` is out of place and trades
    # position with the earliest decode not yet moved. Once a prefill is
    # already inside the front region, all remaining ones are as well.
    for front_decode, back_prefill in zip(decode_slots,
                                          reversed(prefill_slots)):
        if back_prefill < prefill_count:
            break
        input_batch.swap_states(front_decode, back_prefill)
        swapped = True

    # Save for next `build` call
    # TODO(lucas): this is a bit of a hack, we should probably have a
    # better way of doing this
    self._num_decodes = decode_count
    self._num_prefills = prefill_count
    self._num_decode_tokens = decode_token_total
    self._num_prefill_tokens = prefill_token_total

    return swapped
|
||||
|
||||
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> SupaFlashMLASparseMetadata:
    """Assemble the sparse-MLA attention metadata for the current batch.

    Args:
        common_prefix_len: Not used by this builder.
        common_attn_metadata: Batch-level metadata. The optional
            ``seq_start_loc``, ``context_lens`` and ``max_decode_seq_len``
            attributes are used when set and derived from ``seq_lens`` /
            ``query_start_loc`` otherwise.
        fast_build: Not used by this builder.

    Returns:
        A ``SupaFlashMLASparseMetadata`` describing the batch.
    """
    cam = common_attn_metadata
    num_tokens = cam.num_actual_tokens

    # Expand the per-request query boundaries into a per-token request
    # index: token t belongs to request req_ids_np[t].
    query_starts = np.asarray(cam.query_start_loc_cpu, dtype=np.int32)
    tokens_per_req = np.diff(query_starts)
    req_ids_np = np.repeat(
        np.arange(tokens_per_req.shape[0], dtype=np.int32), tokens_per_req)

    # Zero-fill for cudagraphs
    self.req_id_per_token_buffer.fill_(0)
    self.req_id_per_token_buffer[:req_ids_np.shape[0]].copy_(
        torch.from_numpy(req_ids_np), non_blocking=True)
    req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

    fp8_extra_metadata = None
    if self.use_fp8_kv_cache:
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor,
            num_q_tokens_per_head_k=num_tokens * self.num_heads,
            topk=self.topk_tokens,
            num_heads_q=self.num_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        # Copy to persistent buffer for full-CG support
        num_sm_parts = tile_scheduler_metadata.size(0)
        scheduler_view = self.tile_scheduler_metadata_buffer[:num_sm_parts]
        scheduler_view.copy_(tile_scheduler_metadata)
        self.num_splits_buffer.copy_(num_splits)

        fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=scheduler_view,
            num_splits=self.num_splits_buffer,
            # cache_lens and block_table are basically unused in sparse case
            # but the decode kernel will treat -1 and indices >= cache_lens
            # as invalid so we make sure cache_lens is large enough to not
            # accidentally mark indices invalid, we will use -1 exclusively
            # to mark invalid indices
            cache_lens=self.max_model_len_tensor,
            dummy_block_table=self.dummy_block_table)

    # Add biren attention params
    query_start_loc = cam.query_start_loc
    seq_lens = cam.seq_lens
    num_reqs = cam.num_reqs

    seq_start_loc = cam.seq_start_loc
    if seq_start_loc is None:
        # Exclusive prefix sum of the sequence lengths. Batches with more
        # than 8 requests are copied to the host first; smaller ones are
        # accumulated directly from `seq_lens`.
        scan_source = seq_lens.cpu() if len(seq_lens) > 8 else seq_lens
        seq_start_loc = torch.tensor(
            [0] + list(itertools.accumulate(scan_source)),
            device=query_start_loc.device,
            dtype=torch.int32)

    context_lens = cam.context_lens
    if context_lens is None:
        # Context length = total sequence length minus the query tokens
        # of this step.
        query_lens = query_start_loc[1:] - query_start_loc[:-1]
        context_lens = seq_lens - query_lens

    max_decode_seq_len = cam.max_decode_seq_len
    if max_decode_seq_len is None:
        max_decode_seq_len = int(seq_lens.max().item())

    (num_decodes, num_prefills, num_decode_tokens,
     num_prefill_tokens) = split_decodes_and_prefills(
         cam, decode_threshold=self.reorder_batch_threshold)
    assert num_decodes + num_prefills == num_reqs
    assert num_decode_tokens + num_prefill_tokens == num_tokens

    return SupaFlashMLASparseMetadata(
        num_reqs=cam.num_reqs,
        max_query_len=cam.max_query_len,
        max_seq_len=cam.max_seq_len,
        num_actual_tokens=cam.num_actual_tokens,
        query_start_loc=cam.query_start_loc,
        slot_mapping=cam.slot_mapping,
        block_table=cam.block_table_tensor,
        req_id_per_token=req_id_per_token,
        block_size=self.kv_cache_spec.block_size,
        topk_tokens=self.topk_tokens,
        fp8_extra_metadata=fp8_extra_metadata,
        seq_start_loc=seq_start_loc,
        context_lens=context_lens,
        max_decode_seq_len=max_decode_seq_len,
        num_prefills=num_prefills,
        num_decodes=num_decodes,
        num_prefill_tokens=num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
    )
|
||||
|
||||
|
||||
class SupaFlashMLASparseBackend(FlashMLASparseBackend):
    """FlashMLA-sparse attention backend wired to the SUPA classes.

    Overrides the name and the metadata / builder / impl factories of
    ``FlashMLASparseBackend`` and adds the helpers that compute the
    SUPA-specific KV-cache tensor shape.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "SUPA_FLASHMLA_SPARSE_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        """Return the metadata class produced by the builder."""
        return SupaFlashMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["SupaFlashMLASparseMetadataBuilder"]:
        """Return the metadata builder class."""
        return SupaFlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["SupaFlashMLASparseImpl"]:
        """Return the attention implementation class."""
        return SupaFlashMLASparseImpl

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape used on SUPA.

        ``th_gran`` consecutive blocks are merged into one row group and the
        head dimensions are flattened, giving
        ``(2, ceil(num_blocks / th_gran), th_gran * block_size,
        num_kv_heads * head_size)``. At least one row group is always
        allocated.
        """
        th_gran = SupaFlashMLASparseBackend.get_kv_cache_usharp_alignment(
            block_size)
        # Ceiling division, clamped so that the cache is never empty.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        rows = th_gran * block_size
        width = num_kv_heads * head_size
        # Lazy %-formatting: the message is only rendered when DEBUG is on.
        logger.debug(
            "Origin kv cache shape is [2, %s, %s, %s, %s]. "
            "For SUPA Speed up, use [2, %s, %s, %s]", num_blocks, block_size,
            num_kv_heads, head_size, n_block, rows, width)
        return (2, n_block, rows, width)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks of ``block_size`` tokens fit in 2048 rows."""
        max_h_limit = 2048
        return max_h_limit // block_size
|
||||
|
||||
|
||||
class SupaFlashMLASparseImpl(FlashMLASparseImpl):
    """Sparse MLA attention implementation backed by ``torch_br`` kernels.

    Only the non-fp8 KV-cache path is implemented; ``forward`` raises
    ``RuntimeError`` when ``kv_cache_dtype`` is ``"fp8_ds_mla"``.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            topk_indice_buffer: Optional[torch.Tensor] = None,
            indexer: Optional["Indexer"] = None,
            **mla_args) -> None:
        """Forward every argument unchanged to ``FlashMLASparseImpl``."""
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, topk_indice_buffer,
                         indexer, **mla_args)

    @staticmethod
    def _build_topk_index_mask(topk_indices: torch.Tensor, seq_len_q: int,
                               device: torch.device) -> torch.Tensor:
        """Build the ``(1, seq_len_q, seq_len_q)`` int32 top-k mask.

        Entry ``[0, q, k]`` is 0 when ``k`` appears in ``topk_indices[q]``
        and ``0 <= k < seq_len_q``; every other entry is 1. Out-of-range
        indices (including -1) are ignored.

        The mask is assembled with one vectorised scatter instead of a
        Python loop over every (token, k) pair. It is built on the host and
        transferred once, so only stock CPU ops are required.
        NOTE(review): presumably 1 means "masked out" for the kernel —
        confirm against the torch_br kernel documentation.
        """
        idx = topk_indices[:seq_len_q].to(device="cpu", dtype=torch.int64)
        in_range = (idx >= 0) & (idx < seq_len_q)
        # Redirect ignored entries to an extra sentinel column that is
        # sliced away below, so they never touch a real mask entry.
        safe_idx = torch.where(in_range, idx,
                               torch.full_like(idx, seq_len_q))
        mask = torch.ones((seq_len_q, seq_len_q + 1), dtype=torch.int32)
        mask.scatter_(-1, safe_idx, value=0)
        return mask[:, :seq_len_q].unsqueeze(0).contiguous().to(device)

    def _forward_bf16_kv(
            self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
            topk_indices: torch.Tensor,
            attn_metadata: SupaFlashMLASparseMetadata) -> torch.Tensor:
        """Run sparse attention over the non-fp8 KV cache.

        Args:
            q: Query of shape ``(seq_len_q, num_heads, head_dim)``.
            kv_c_and_k_pe_cache: KV cache; only the first slice along dim 0
                is passed to the kernel.
            topk_indices: Per-token selected key indices, one row per query
                token.
            attn_metadata: Metadata produced by the builder.

        Returns:
            Tensor of shape ``(seq_len_q, num_heads, kv_lora_rank)``.
        """
        seq_len_q, num_heads, _ = q.shape
        index_mask = self._build_topk_index_mask(topk_indices, seq_len_q,
                                                 q.device)

        # [num_heads, seq_len, head_dim]
        query = q.transpose(0, 1).contiguous()
        output = torch_br.supa_flash_attn_cache_infer(
            query,
            # [1, num_blocks, block_size, self.head_size]
            kv_c_and_k_pe_cache[:1],
            attn_metadata.query_start_loc,
            attn_metadata.seq_start_loc,
            attn_metadata.block_table,
            attn_metadata.context_lens,
            attn_metadata.slot_mapping,
            attn_metadata.max_seq_len,
            self.head_size,
            softmax_scale=self.softmax_scale,
            v_head_size=self.kv_lora_rank,
            mask=index_mask)

        # Bring the kernel output back to (tokens, heads, kv_lora_rank).
        return output.reshape(seq_len_q, num_heads,
                              self.kv_lora_rank).contiguous()

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: SupaFlashMLASparseMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute sparse MLA attention and write the result into ``output``.

        Args:
            layer: Attention layer (not read by this implementation).
            q: Query tensor, possibly padded beyond the actual tokens.
            k_c_normed: Latent KV input, stored in the cache.
            k_pe: Rotary key part, stored in the cache next to the latent.
            kv_cache: 4-D KV cache tensor; skipped for storing when empty.
            attn_metadata: Builder metadata, or ``None`` for a dummy run.
            output: Pre-allocated output tensor (required).
            output_scale: Must be ``None``.
            output_block_scale: Must be ``None``.

        Returns:
            ``output``, with its first ``num_actual_tokens`` rows filled.

        Raises:
            NotImplementedError: If fused output quantization is requested.
            RuntimeError: If the KV cache dtype is ``"fp8_ds_mla"``.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for MLACommonImpl")

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                               dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        ql_nope = ql_nope.transpose(0, 1)

        # TODO: handle index / kv_cache correctly. The per-request indices
        # are used as-is; they are not converted to global cache indices.
        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            # Unpacking also checks that the cache is 4-D.
            _, _, _, head_size = kv_cache.shape
            k_pe_tmp = k_pe.squeeze(1).unsqueeze(0)
            key_supa = torch.cat([k_c_normed, k_pe_tmp], dim=2)
            # The same tensor is passed for both key and value arguments.
            torch_br.supa_kvcache_store_infer_v2(kv_cache, key_supa, key_supa,
                                                 attn_metadata.slot_mapping,
                                                 head_size)

        if self.kv_cache_dtype == "fp8_ds_mla":
            raise RuntimeError("Not support fp8 on br.")
        attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
                                         attn_metadata)

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output
|
||||
140
vllm_br/v1/attention/backends/mla/indexer.py
Normal file
140
vllm_br/v1/attention/backends/mla/indexer.py
Normal file
@@ -0,0 +1,140 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.indexer import (
|
||||
DeepseekV32IndexerBackend, DeepSeekV32IndexerDecodeMetadata,
|
||||
DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadataBuilder,
|
||||
DeepseekV32IndexerPrefillMetadata, split_prefill_chunks)
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SupaDeepseekV32IndexerBackend(DeepseekV32IndexerBackend):
    """DeepSeek-V3.2 indexer backend using the SUPA metadata builder.

    Also provides the helpers that compute the SUPA-specific KV-cache tensor
    shape for the indexer cache.
    """

    @staticmethod
    def get_builder_cls() -> type["SupaDeepseekV32IndexerMetadataBuilder"]:
        """Return the metadata builder class."""
        return SupaDeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_usharp_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the indexer KV-cache tensor shape used on SUPA.

        ``th_gran`` consecutive blocks are merged into one row group and the
        head dimensions are flattened, giving
        ``(1, ceil(num_blocks / th_gran), th_gran * block_size,
        num_kv_heads * head_size)``. At least one row group is always
        allocated.
        """
        th_gran = SupaDeepseekV32IndexerBackend.get_kv_cache_usharp_alignment(
            block_size)
        # Ceiling division, clamped so that the cache is never empty.
        n_block = max(1, (num_blocks + th_gran - 1) // th_gran)
        rows = th_gran * block_size
        width = num_kv_heads * head_size
        # Lazy %-formatting: the message is only rendered when DEBUG is on.
        logger.debug(
            "Origin kv cache shape is [1, %s, %s, %s, %s]. "
            "For SUPA Speed up, use [1, %s, %s, %s]", num_blocks, block_size,
            num_kv_heads, head_size, n_block, rows, width)
        return (1, n_block, rows, width)

    @staticmethod
    def get_kv_cache_usharp_alignment(block_size: int) -> int:
        """Return how many blocks of ``block_size`` tokens fit in 2048 rows."""
        max_h_limit = 2048
        return max_h_limit // block_size
|
||||
|
||||
|
||||
class SupaDeepseekV32IndexerMetadataBuilder(DeepseekV32IndexerMetadataBuilder):
    """Indexer metadata builder that skips the scheduler-metadata kernel."""

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> DeepseekV32IndexerMetadata:
        """Build the indexer metadata for the current batch.

        The batch is split into a decode part and a prefill part with
        ``split_decodes_and_prefills``; each non-empty part gets its own
        sub-metadata.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Batch-level attention metadata.
            fast_build: Not used by this builder.

        Returns:
            The assembled ``DeepseekV32IndexerMetadata``.
        """
        cam = common_attn_metadata
        num_reqs = cam.num_reqs
        num_tokens = cam.num_actual_tokens
        query_start_loc_cpu = cam.query_start_loc_cpu

        (num_decodes, num_prefills, num_decode_tokens,
         num_prefill_tokens) = split_decodes_and_prefills(
             cam, decode_threshold=self.reorder_batch_threshold)
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            chunk_bounds = split_prefill_chunks(
                cam.seq_lens_cpu,
                self.max_prefill_buffer_size,
                num_decodes,
            )
            chunks = []
            for reqs_start, reqs_end in chunk_bounds:
                chunks.append(
                    self.build_one_prefill_chunk(reqs_start, reqs_end,
                                                 query_start_loc_cpu,
                                                 cam.seq_lens_cpu,
                                                 cam.block_table_tensor))
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=chunks)

        decode_metadata = None
        if num_decodes > 0:
            # Per-request query lengths of the decode part, written into
            # the persistent buffer.
            decode_lens = self.decode_lens_buffer[:num_decodes]
            torch.diff(cam.query_start_loc[:num_decodes + 1],
                       out=decode_lens)

            # Use CPU to avoid GPU sync; breaking async scheduling
            decode_lens_cpu = torch.diff(
                cam.query_start_loc_cpu[:num_decodes + 1])
            longest = decode_lens_cpu.max()
            shortest = decode_lens_cpu.min()
            requires_padding = (longest > shortest).item()

            # The upstream get_paged_mqa_logits_metadata call is not made
            # here; the schedule metadata is passed on as None.
            self.scheduler_metadata_buffer = None
            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=cam.block_table_tensor[:num_decodes, ...],
                seq_lens=cam.seq_lens[:num_decodes],
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                schedule_metadata=self.scheduler_metadata_buffer,
            )

        return DeepseekV32IndexerMetadata(
            seq_lens=cam.seq_lens,
            num_reqs=cam.num_reqs,
            max_query_len=cam.max_query_len,
            max_seq_len=cam.max_seq_len,
            num_actual_tokens=cam.num_actual_tokens,
            query_start_loc=cam.query_start_loc,
            slot_mapping=cam.slot_mapping,
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )
|
||||
47
vllm_br/v1/attention/backends/utils.py
Normal file
47
vllm_br/v1/attention/backends/utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm_br.config.compilation import SUPAGraphMode
|
||||
|
||||
|
||||
@dataclass
class SUPACommonAttentionMetadata(CommonAttentionMetadata):
    """
    Attention metadata shared by layers in different KV cache groups (which
    may have different block tables), extended with optional SUPA fields.
    """

    query_start_loc: torch.Tensor
    """(batch_size + 1,), start offset of each request in the query tensor."""

    seq_lens: torch.Tensor
    """(batch_size,), length of each request, counting both the already
    computed tokens and the newly scheduled ones."""

    num_actual_reqs: torch.Tensor | None = None
    """(1,), number of actual requests in the batch."""

    # Graph mode of the current run; semantics are defined by SUPAGraphMode.
    supagraph_runtime_mode: SUPAGraphMode | None = None

    context_lens: torch.Tensor | None = None
    """(batch_size,), length of each request counting computed tokens only."""

    max_decode_seq_len: int | None = None
    """Maximum length of the decoded sequences in the batch."""

    seq_start_loc: torch.Tensor | None = None
    """(batch_size + 1,), start offset of each request in the sequence
    tensor, from which per-request sequence lengths can be derived.
    When left as None it is computed from seq_lens."""
|
||||
17
vllm_br/v1/core/__init__.py
Normal file
17
vllm_br/v1/core/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import kv_cache_utils, sched # noqa: F401
|
||||
BIN
vllm_br/v1/core/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/core/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/core/__pycache__/kv_cache_utils.cpython-310.pyc
Normal file
BIN
vllm_br/v1/core/__pycache__/kv_cache_utils.cpython-310.pyc
Normal file
Binary file not shown.
219
vllm_br/v1/core/kv_cache_utils.py
Normal file
219
vllm_br/v1/core/kv_cache_utils.py
Normal file
@@ -0,0 +1,219 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
create_kv_cache_group_specs, get_max_concurrency_for_kv_cache_config,
|
||||
get_uniform_page_size, may_override_num_blocks)
|
||||
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
|
||||
KVCacheSpec, KVCacheTensor,
|
||||
UniformTypeKVCacheSpecs)
|
||||
from vllm_br.v1.attention.backends.attention_v1 import (
|
||||
SUPAFlashAttentionBackend)
|
||||
|
||||
|
||||
@patch_to(vllm.v1.core.kv_cache_utils)
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                      kv_cache_spec: dict[str, KVCacheSpec],
                                      available_memory: int) -> KVCacheConfig:
    """
    Generates the KV cache configuration for a model with one type of KV cache.
    Divide the available memory equally among all layers.

    The block count is rounded down to a multiple of the SUPA alignment
    granularity and capped at ``th_gran * 1024`` before any user override is
    applied.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes.

    Returns:
        The generated KVCacheConfig
    """
    page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
    assert len(page_sizes) == 1
    page_size = page_sizes.pop()

    # NOTE: SUPA has layouts
    # Both MLA/FlashAttention use the same gran
    th_gran = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
        vllm_config.cache_config.block_size)
    num_blocks = int(available_memory // page_size // len(kv_cache_spec))

    # NOTE: limit gpu blocks number due to the shape restriction of colmajor layout
    num_blocks = min(th_gran * 1024, num_blocks // th_gran * th_gran)
    num_blocks = max(num_blocks, 0)

    # Honour cache_config.num_gpu_blocks_override through the same helper
    # the other config builders in this module use.
    num_blocks = may_override_num_blocks(vllm_config, num_blocks)

    num_tokens = num_blocks * vllm_config.cache_config.block_size
    num_tokens_str = f"{num_tokens:,}"
    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
    max_concurrency = num_tokens / vllm_config.model_config.max_model_len
    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                max_model_len_str, max_concurrency)

    per_layer_size = page_size * num_blocks
    # All layers have the same KV cache spec, so we create one kv cache group
    # for all layers.
    grouped_layer_names = [list(kv_cache_spec.keys())]

    # Each layer owns a separate tensor. Built with the same
    # `kv_cache_tensors` / `shared_by` API that
    # get_kv_cache_config_from_groups uses for KVCacheConfig/KVCacheTensor.
    kv_cache_tensors = [
        KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
        for layer_name in kv_cache_spec
    ]

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
                                                    grouped_layer_names),
    )
|
||||
|
||||
|
||||
logger.info('===[Patch] patch _get_kv_cache_config_uniform_type')
|
||||
|
||||
|
||||
# @patch_to(vllm.v1.core.kv_cache_utils)
|
||||
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
                   available_memory: int, page_size: int) -> int:
    """
    Compute how many KV cache blocks to allocate.

    The memory is split evenly across ``num_layers``; the result is rounded
    down to a multiple of the SUPA alignment granularity, capped at
    ``th_gran * 1024``, floored at zero, and finally passed through
    ``may_override_num_blocks``.

    Args:
        vllm_config: The global VllmConfig
        num_layers: The number of layers
        available_memory: Memory available for KV cache in bytes.
        page_size: The page size of the KV cache.

    Returns:
        The number of KV cache blocks.
    """
    granularity = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
        vllm_config.cache_config.block_size)
    raw_blocks = int(available_memory // page_size // num_layers)
    aligned_blocks = raw_blocks // granularity * granularity
    capped_blocks = max(min(granularity * 1024, aligned_blocks), 0)
    return may_override_num_blocks(vllm_config, capped_blocks)
|
||||
|
||||
|
||||
@patch_to(vllm.v1.core.kv_cache_utils)
def get_kv_cache_config_from_groups(vllm_config: VllmConfig,
                                    kv_cache_groups: list[KVCacheGroupSpec],
                                    kv_cache_specs: dict[str, KVCacheSpec],
                                    available_memory: int) -> KVCacheConfig:
    """
    Build the KV cache configuration from the KV cache groups and the spec of
    each layer, using SUPA-aligned block counts.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_groups: The KV cache groups
        kv_cache_specs: The KV cache spec of each attention layer in the model
        available_memory: Memory available for KV cache in bytes
    Returns:
        The generated KVCacheConfig
    """
    if not kv_cache_groups:
        # Attention free models do not have KV cache.
        # Return num_blocks=1 as BlockPool always needs a null_block.
        return KVCacheConfig(
            num_blocks=1,
            kv_cache_tensors=[],
            kv_cache_groups=kv_cache_groups,
        )

    # Determine how model runners should initialize the KV cache tensors.
    first_group = kv_cache_groups[0]
    single_uniform_group = (len(kv_cache_groups) == 1 and isinstance(
        first_group.kv_cache_spec, UniformTypeKVCacheSpecs))

    if single_uniform_group:
        # Special case: all layers have the same type of KV cache but with
        # different hidden size. Allocate different amount of memory for each
        # layer based on its hidden size.
        granularity = SUPAFlashAttentionBackend.get_kv_cache_usharp_alignment(
            vllm_config.cache_config.block_size)
        raw_blocks = available_memory // first_group.kv_cache_spec.page_size_bytes
        aligned_blocks = raw_blocks // granularity * granularity
        num_blocks = max(min(granularity * 1024, aligned_blocks), 0)
        num_blocks = may_override_num_blocks(vllm_config, num_blocks)

        per_layer_specs = first_group.kv_cache_spec.kv_cache_specs
        kv_cache_tensors = []
        for layer_name in first_group.layer_names:
            layer_bytes = per_layer_specs[layer_name].page_size_bytes * num_blocks
            kv_cache_tensors.append(
                KVCacheTensor(size=layer_bytes, shared_by=[layer_name]))
    else:
        # General case:
        # We will have group_size memory pools, each is shared by one layer from
        # each group. As layers of different groups have different block table,
        # they will use different parts of the shared Tensor.
        # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
        # (sw.1, padding) will be: (group_size = 2)
        # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
        # full.1, sw.2: share another Tensor with size=available_memory//2
        group_size = max(len(group.layer_names) for group in kv_cache_groups)

        page_size = get_uniform_page_size(kv_cache_specs)
        assert group_size > 0, "group_size must be greater than 0"
        num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
                                    page_size)
        # Pool i is shared by the i-th layer of every group that has one.
        kv_cache_tensors = [
            KVCacheTensor(size=page_size * num_blocks,
                          shared_by=[
                              group.layer_names[i]
                              for group in kv_cache_groups
                              if i < len(group.layer_names)
                          ]) for i in range(group_size)
        ]

    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=kv_cache_groups,
    )

    min_block_size = min(group.kv_cache_spec.block_size
                         for group in kv_cache_groups)

    # Print the KV cache size and maximum concurrency.
    num_tokens = num_blocks // len(kv_cache_groups) * min_block_size
    dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
    if dcp_world_size > 1:
        num_tokens *= dcp_world_size
        logger.info(
            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
            dcp_world_size)
    num_tokens_str = f"{num_tokens:,}"
    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
    max_concurrency = get_max_concurrency_for_kv_cache_config(
        vllm_config, kv_cache_config)
    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                max_model_len_str, max_concurrency)
    return kv_cache_config
|
||||
|
||||
|
||||
logger.info('===[Patch] patch get_kv_cache_config_from_groups')
|
||||
17
vllm_br/v1/core/sched/__init__.py
Normal file
17
vllm_br/v1/core/sched/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import scheduler # noqa: F401
|
||||
BIN
vllm_br/v1/core/sched/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/core/sched/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/core/sched/__pycache__/scheduler.cpython-310.pyc
Normal file
BIN
vllm_br/v1/core/sched/__pycache__/scheduler.cpython-310.pyc
Normal file
Binary file not shown.
558
vllm_br/v1/core/sched/scheduler.py
Normal file
558
vllm_br/v1/core/sched/scheduler.py
Normal file
@@ -0,0 +1,558 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.distributed.kv_events import KVEventBatch
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
|
||||
create_request_queue)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.engine import EngineCoreEventType
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@patch_to(Scheduler)
def schedule(self) -> SchedulerOutput:
    """Run one scheduling step and return the resulting ``SchedulerOutput``.

    This function is attached to ``Scheduler`` through ``patch_to``.

    The step proceeds as follows:

    1. Walk ``self.running`` and give each request as many tokens as the
       token budget, the long-prefill threshold and ``max_model_len`` allow,
       preempting other running requests when KV-cache slots cannot be
       allocated.
    2. If nothing was preempted in step 1, admit requests from
       ``self.waiting`` until the token budget or ``max_num_running_reqs``
       is exhausted. Requests that cannot be admitted right now are
       collected in a temporary queue and put back at the head of
       ``self.waiting`` afterwards.
    3. Build the ``SchedulerOutput``, attach KV-connector metadata when a
       connector is configured, publish collected KV-cache events and call
       ``self._update_after_schedule``.

    Two checks in this function are tied to ``self.num_spec_tokens + 1``:
    a chunked-prefill request in the running queue is never given a chunk
    of that size or smaller unless it finishes the prompt, and a waiting
    request whose schedulable token count is that size or smaller is
    skipped for this step.

    Returns:
        The ``SchedulerOutput`` describing everything scheduled in this
        step.
    """
    # NOTE(woosuk) on the scheduling algorithm:
    # There's no "decoding phase" nor "prefill phase" in the scheduler.
    # Each request just has the num_computed_tokens and
    # num_tokens_with_spec. num_tokens_with_spec =
    # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
    # At each step, the scheduler tries to assign tokens to the requests
    # so that each request's num_computed_tokens can catch up its
    # num_tokens_with_spec. This is general enough to cover
    # chunked prefills, prefix caching, speculative decoding,
    # and the "jump decoding" optimization in the future.

    scheduled_new_reqs: list[Request] = []
    scheduled_resumed_reqs: list[Request] = []
    scheduled_running_reqs: list[Request] = []
    preempted_reqs: list[Request] = []

    req_to_new_blocks: dict[str, KVCacheBlocks] = {}
    num_scheduled_tokens: dict[str, int] = {}
    token_budget = self.max_num_scheduled_tokens
    # Encoder-related.
    scheduled_encoder_inputs: dict[str, list[int]] = {}
    encoder_compute_budget = self.max_num_encoder_input_tokens
    # Spec decode-related.
    scheduled_spec_decode_tokens: dict[str, list[int]] = {}

    # For logging.
    scheduled_timestamp = time.monotonic()

    # First, schedule the RUNNING requests.
    req_index = 0
    while req_index < len(self.running) and token_budget > 0:
        request = self.running[req_index]

        # Tokens this request still needs, capped by the long-prefill
        # threshold (when it is positive) and by the remaining budget.
        num_new_tokens = (request.num_tokens_with_spec +
                          request.num_output_placeholders -
                          request.num_computed_tokens)
        if (0 < self.scheduler_config.long_prefill_token_threshold <
                num_new_tokens):
            num_new_tokens = (
                self.scheduler_config.long_prefill_token_threshold)
        num_new_tokens = min(num_new_tokens, token_budget)

        # Make sure the input position does not exceed the max model len.
        # This is necessary when using spec decoding.
        num_new_tokens = min(
            num_new_tokens,
            self.max_model_len - 1 - request.num_computed_tokens)

        # Schedule encoder inputs.
        encoder_inputs_to_schedule = None
        new_encoder_compute_budget = encoder_compute_budget
        if request.has_encoder_inputs:
            (encoder_inputs_to_schedule, num_new_tokens,
             new_encoder_compute_budget) = self._try_schedule_encoder_inputs(
                 request, request.num_computed_tokens, num_new_tokens,
                 encoder_compute_budget)

        # Chunk-size adjustment for requests that have produced no output
        # token yet (i.e. are still being prefilled).
        # NOTE(review): this runs after the encoder inputs were chosen for
        # the unadjusted num_new_tokens; confirm that shrinking the chunk
        # here cannot leave encoder_inputs_to_schedule out of sync.
        if self.scheduler_config.chunked_prefill_enabled and request.num_output_tokens == 0:
            # shortest chunked prefill length is num_spec_tokens + 1
            prefill_schedul_threshold = self.num_spec_tokens + 1
            # Calculate remaining prompt tokens when request is in prefill phase
            remaining_prompt_tokens = request.num_tokens - request.num_computed_tokens - num_new_tokens
            if num_new_tokens > prefill_schedul_threshold:
                # Boundary condition: when remaining tokens equal or less than threshold,
                # reduce current round's token count to prevent phase misclassification
                # in reorder batch later in next round
                if 0 < remaining_prompt_tokens <= prefill_schedul_threshold:
                    # After this subtraction exactly threshold + 1 tokens
                    # are left over for the next round.
                    num_new_tokens -= (prefill_schedul_threshold -
                                       remaining_prompt_tokens + 1)
                    # If the shrunken chunk fell below the threshold, do
                    # not schedule the request in this step at all.
                    num_new_tokens = 0 if num_new_tokens < prefill_schedul_threshold else num_new_tokens
            elif remaining_prompt_tokens > 0:
                # cannot schedule less than threshold tokens in chunked prefill
                num_new_tokens = 0

        if num_new_tokens == 0:
            # The request cannot be scheduled because one of the following
            # reasons:
            # 1. No new tokens to schedule. This may happen when
            #    (1) PP>1 and we have already scheduled all prompt tokens
            #    but they are not finished yet.
            #    (2) Async scheduling and the request has reached either
            #    its max_total_tokens or max_model_len.
            #    (3) The chunked-prefill adjustment above reduced the
            #    chunk to zero.
            # 2. The encoder budget is exhausted.
            # 3. The encoder cache is exhausted.
            # NOTE(woosuk): Here, by doing `continue` instead of `break`,
            # we do not strictly follow the FCFS scheduling policy and
            # allow the lower-priority requests to be scheduled.
            req_index += 1
            continue

        # Try to allocate KV-cache slots; on failure keep preempting running
        # requests until allocation succeeds or `request` itself has been
        # preempted.
        while True:
            new_blocks = self.kv_cache_manager.allocate_slots(
                request,
                num_new_tokens,
                num_lookahead_tokens=self.num_lookahead_tokens)
            if new_blocks is None:
                # The request cannot be scheduled.
                # Preempt the lowest-priority request.
                if self.policy == SchedulingPolicy.PRIORITY:
                    # Victim is the running request with the largest
                    # (priority, arrival_time) key.
                    preempted_req = max(
                        self.running,
                        key=lambda r: (r.priority, r.arrival_time),
                    )
                    self.running.remove(preempted_req)
                    if preempted_req in scheduled_running_reqs:
                        scheduled_running_reqs.remove(preempted_req)
                else:
                    # Otherwise the victim is the last entry of the
                    # running list.
                    preempted_req = self.running.pop()

                # Release the victim's resources and reset its progress so
                # that it restarts from the waiting queue.
                self.kv_cache_manager.free(preempted_req)
                self.encoder_cache_manager.free(preempted_req)
                preempted_req.status = RequestStatus.PREEMPTED
                preempted_req.num_computed_tokens = 0
                if self.log_stats:
                    preempted_req.record_event(EngineCoreEventType.PREEMPTED,
                                               scheduled_timestamp)

                self.waiting.prepend_request(preempted_req)
                preempted_reqs.append(preempted_req)
                if preempted_req == request:
                    # No more request to preempt.
                    can_schedule = False
                    break
            else:
                # The request can be scheduled.
                can_schedule = True
                break
        if not can_schedule:
            break
        assert new_blocks is not None

        # Schedule the request.
        scheduled_running_reqs.append(request)
        req_to_new_blocks[request.request_id] = new_blocks
        num_scheduled_tokens[request.request_id] = num_new_tokens
        token_budget -= num_new_tokens
        req_index += 1

        # Speculative decode related.
        if request.spec_token_ids:
            # Number of scheduled tokens that lie beyond request.num_tokens,
            # i.e. that are speculative tokens.
            num_scheduled_spec_tokens = (num_new_tokens +
                                         request.num_computed_tokens -
                                         request.num_tokens)
            if num_scheduled_spec_tokens > 0:
                # Trim spec_token_ids list to num_scheduled_spec_tokens.
                del request.spec_token_ids[num_scheduled_spec_tokens:]
                scheduled_spec_decode_tokens[request.request_id] = (
                    request.spec_token_ids)

        # Encoder-related.
        if encoder_inputs_to_schedule:
            scheduled_encoder_inputs[request.request_id] = (
                encoder_inputs_to_schedule)
            # Allocate the encoder cache.
            for i in encoder_inputs_to_schedule:
                self.encoder_cache_manager.allocate(request, i)
            encoder_compute_budget = new_encoder_compute_budget

    # Record the LoRAs in scheduled_running_reqs
    scheduled_loras: set[int] = set()
    if self.lora_config:
        scheduled_loras = set(
            req.lora_request.lora_int_id for req in scheduled_running_reqs
            if req.lora_request and req.lora_request.lora_int_id > 0)
        assert len(scheduled_loras) <= self.lora_config.max_loras

    # Use a temporary RequestQueue to collect requests that need to be
    # skipped and put back at the head of the waiting queue later
    skipped_waiting_requests = create_request_queue(self.policy)

    # Next, schedule the WAITING requests.
    # Waiting requests are only admitted when no request was preempted in
    # the loop above.
    if not preempted_reqs:
        while self.waiting and token_budget > 0:
            if len(self.running) == self.max_num_running_reqs:
                break

            # Only peek here; the request is popped later, either when it
            # is skipped or once its allocation has succeeded.
            request = self.waiting.peek_request()

            # KVTransfer: skip request if still waiting for remote kvs.
            if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
                is_ready = self._update_waiting_for_remote_kv(request)
                if is_ready:
                    request.status = RequestStatus.WAITING
                else:
                    logger.debug(
                        "%s is still in WAITING_FOR_REMOTE_KVS state.",
                        request.request_id)
                    self.waiting.pop_request()
                    skipped_waiting_requests.prepend_request(request)
                    continue

            # Skip request if the structured output request is still waiting
            # for FSM compilation.
            if request.status == RequestStatus.WAITING_FOR_FSM:
                structured_output_req = request.structured_output_request
                if structured_output_req and structured_output_req.grammar:
                    request.status = RequestStatus.WAITING
                else:
                    self.waiting.pop_request()
                    skipped_waiting_requests.prepend_request(request)
                    continue

            # Check that adding the request still respects the max_loras
            # constraint.
            if (self.lora_config and request.lora_request and
                (len(scheduled_loras) == self.lora_config.max_loras
                 and request.lora_request.lora_int_id not in scheduled_loras)):
                # Scheduling would exceed max_loras, skip.
                self.waiting.pop_request()
                skipped_waiting_requests.prepend_request(request)
                continue

            num_external_computed_tokens = 0
            load_kv_async = False

            # Get already-cached tokens.
            if request.num_computed_tokens == 0:
                # Get locally-cached tokens.
                new_computed_blocks, num_new_local_computed_tokens = \
                    self.kv_cache_manager.get_computed_blocks(
                        request)

                # Get externally-cached tokens if using a KVConnector.
                if self.connector is not None:
                    num_external_computed_tokens, load_kv_async = (
                        self.connector.get_num_new_matched_tokens(
                            request, num_new_local_computed_tokens))

                    if num_external_computed_tokens is None:
                        # The request cannot be scheduled because
                        # the KVConnector couldn't determine
                        # the number of matched tokens.
                        self.waiting.pop_request()
                        skipped_waiting_requests.prepend_request(request)
                        continue

                # Total computed tokens (local + external).
                num_computed_tokens = (num_new_local_computed_tokens +
                                       num_external_computed_tokens)
            # KVTransfer: WAITING reqs have num_computed_tokens > 0
            # after async KV recvs are completed.
            else:
                new_computed_blocks = (
                    self.kv_cache_manager.create_empty_block_list())
                num_new_local_computed_tokens = 0
                num_computed_tokens = request.num_computed_tokens

            encoder_inputs_to_schedule = None
            new_encoder_compute_budget = encoder_compute_budget

            # KVTransfer: loading remote KV, do not allocate for new work.
            if load_kv_async:
                assert num_external_computed_tokens > 0
                num_new_tokens = 0
            # Number of tokens to be scheduled.
            else:
                # We use `request.num_tokens` instead of
                # `request.num_prompt_tokens` to consider the resumed
                # requests, which have output tokens.
                num_new_tokens = request.num_tokens - num_computed_tokens
                if (0 < self.scheduler_config.long_prefill_token_threshold <
                        num_new_tokens):
                    num_new_tokens = (
                        self.scheduler_config.long_prefill_token_threshold)

                # chunked prefill has to be enabled explicitly to allow
                # pooling requests to be chunked
                if not self.scheduler_config.chunked_prefill_enabled and \
                    num_new_tokens > token_budget:
                    self.waiting.pop_request()
                    skipped_waiting_requests.prepend_request(request)
                    continue

                num_new_tokens = min(num_new_tokens, token_budget)
                assert num_new_tokens > 0

                # Schedule encoder inputs.
                if request.has_encoder_inputs:
                    (encoder_inputs_to_schedule, num_new_tokens,
                     new_encoder_compute_budget
                     ) = self._try_schedule_encoder_inputs(
                         request, num_computed_tokens, num_new_tokens,
                         encoder_compute_budget)
                    if num_new_tokens == 0:
                        # The request cannot be scheduled.
                        break

                # NOTE(review): this check only makes sense for a token
                # count computed above from the request (the async-load
                # path sets it to 0); confirm its nesting level against
                # the original file.
                if num_new_tokens <= self.num_spec_tokens + 1:
                    # Too short waiting requests can not be scheduled.
                    self.waiting.pop_request()
                    skipped_waiting_requests.prepend_request(request)
                    continue

            # Handles an edge case when P/D Disaggregation
            # is used with Spec Decoding where an
            # extra block gets allocated which
            # creates a mismatch between the number
            # of local and remote blocks.
            effective_lookahead_tokens = (0 if request.num_computed_tokens == 0
                                          else self.num_lookahead_tokens)

            # Determine if we need to allocate cross-attention blocks.
            if self.is_encoder_decoder and request.has_encoder_inputs:
                # TODO(russellb): For Whisper, we know that the input is
                # always padded to the maximum length. If we support other
                # encoder-decoder models, this will need to be updated if we
                # want to only allocate what is needed.
                num_encoder_tokens =\
                    self.scheduler_config.max_num_encoder_input_tokens
            else:
                num_encoder_tokens = 0

            new_blocks = self.kv_cache_manager.allocate_slots(
                request,
                num_new_tokens + num_external_computed_tokens,
                num_new_local_computed_tokens,
                new_computed_blocks,
                num_lookahead_tokens=effective_lookahead_tokens,
                delay_cache_blocks=load_kv_async,
                num_encoder_tokens=num_encoder_tokens,
            )

            if new_blocks is None:
                # The request cannot be scheduled.
                break

            # KVTransfer: pass the allocated blocks and the number of
            # externally computed tokens to the connector, which uses this
            # information to determine if a load is needed for this request.
            if self.connector is not None:
                self.connector.update_state_after_alloc(
                    request,
                    new_computed_blocks + new_blocks,
                    num_external_computed_tokens,
                )

            # The request has only been peeked so far; allocation
            # succeeded, so remove it from the waiting queue now.
            request = self.waiting.pop_request()
            if load_kv_async:
                # If loading async, allocate memory and put request
                # into the WAITING_FOR_REMOTE_KV state.
                skipped_waiting_requests.prepend_request(request)
                request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                continue

            req_index += 1
            self.running.append(request)
            if self.log_stats:
                request.record_event(EngineCoreEventType.SCHEDULED,
                                     scheduled_timestamp)
            # The status before admission tells a brand-new request from one
            # that is resuming after preemption.
            if request.status == RequestStatus.WAITING:
                scheduled_new_reqs.append(request)
            elif request.status == RequestStatus.PREEMPTED:
                scheduled_resumed_reqs.append(request)
            else:
                raise RuntimeError(f"Invalid request status: {request.status}")

            if self.lora_config and request.lora_request:
                scheduled_loras.add(request.lora_request.lora_int_id)
            req_to_new_blocks[request.request_id] = (
                self.kv_cache_manager.get_blocks(request.request_id))
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            request.status = RequestStatus.RUNNING
            request.num_computed_tokens = num_computed_tokens
            # Count the number of prefix cached tokens.
            if request.num_cached_tokens < 0:
                request.num_cached_tokens = num_computed_tokens
            # Encoder-related.
            if encoder_inputs_to_schedule:
                scheduled_encoder_inputs[request.request_id] = (
                    encoder_inputs_to_schedule)
                # Allocate the encoder cache.
                for i in encoder_inputs_to_schedule:
                    self.encoder_cache_manager.allocate(request, i)
                encoder_compute_budget = new_encoder_compute_budget

    # Put back any skipped requests at the head of the waiting queue
    if skipped_waiting_requests:
        self.waiting.prepend_requests(skipped_waiting_requests)

    # Check if the scheduling constraints are satisfied.
    total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
    assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
    assert token_budget >= 0
    assert len(self.running) <= self.max_num_running_reqs
    # Since some requests in the RUNNING queue may not be scheduled in
    # this step, the total number of scheduled requests can be smaller than
    # len(self.running).
    assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
            len(scheduled_running_reqs) <= len(self.running))

    # Get the longest common prefix among all requests in the running queue.
    # This can be potentially used for cascade attention.
    num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
    if self.running:
        any_request = self.running[0]
        num_common_prefix_blocks = (
            self.kv_cache_manager.get_num_common_prefix_blocks(
                any_request, len(self.running)))

    # Construct the scheduler output.
    new_reqs_data = [
        NewRequestData.from_request(
            req, req_to_new_blocks[req.request_id].get_block_ids())
        for req in scheduled_new_reqs
    ]
    cached_reqs_data = self._make_cached_request_data(
        scheduled_running_reqs,
        scheduled_resumed_reqs,
        num_scheduled_tokens,
        scheduled_spec_decode_tokens,
        req_to_new_blocks,
    )
    scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs +
                          scheduled_resumed_reqs)
    structured_output_request_ids, grammar_bitmask = (self.get_grammar_bitmask(
        scheduled_requests, scheduled_spec_decode_tokens))
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=new_reqs_data,
        scheduled_cached_reqs=cached_reqs_data,
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=total_num_scheduled_tokens,
        scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
        scheduled_encoder_inputs=scheduled_encoder_inputs,
        num_common_prefix_blocks=num_common_prefix_blocks,
        # finished_req_ids is an existing state in the scheduler,
        # instead of being newly scheduled in this step.
        # It contains the request IDs that are finished in between
        # the previous and the current steps.
        finished_req_ids=self.finished_req_ids,
        free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(
        ),
        structured_output_request_ids=structured_output_request_ids,
        grammar_bitmask=grammar_bitmask,
    )

    # NOTE(Kuntai): this function is designed for multiple purposes:
    # 1. Plan the KV cache store
    # 2. Wrap up all the KV cache load / save ops into an opaque object
    # 3. Clear the internal states of the connector
    if self.connector is not None:
        meta = self.connector.build_connector_meta(scheduler_output)
        scheduler_output.kv_connector_metadata = meta

    # collect KV cache events from KV cache manager
    events = self.kv_cache_manager.take_events()

    # collect KV cache events from connector
    if self.connector is not None:
        connector_events = self.connector.take_events()
        if connector_events:
            if events is None:
                events = list(connector_events)
            else:
                events.extend(connector_events)

    # publish collected KV cache events
    if events:
        batch = KVEventBatch(ts=time.time(), events=events)
        self.kv_event_publisher.publish(batch)

    self._update_after_schedule(scheduler_output)
    return scheduler_output
|
||||
|
||||
|
||||
@patch_to(Scheduler)
def _make_cached_request_data(
    self,
    running_reqs: list[Request],
    resumed_reqs: list[Request],
    num_scheduled_tokens: dict[str, int],
    spec_decode_tokens: dict[str, list[int]],
    req_to_new_blocks: dict[str, KVCacheBlocks],
) -> CachedRequestData:
    """Assemble the ``CachedRequestData`` for running and resumed requests.

    The per-request lists are filled in the order ``running_reqs`` followed
    by ``resumed_reqs``.

    Args:
        running_reqs: Requests scheduled from the running queue.
        resumed_reqs: Requests scheduled again after a preemption.
        num_scheduled_tokens: Scheduled token count per request id.
        spec_decode_tokens: Scheduled speculative token ids per request id;
            requests without an entry count as having none.
        req_to_new_blocks: Newly allocated KV-cache blocks per request id.

    Returns:
        A ``CachedRequestData`` whose lists are index-aligned with
        ``req_ids``.
    """
    has_connector = self.connector is not None

    req_ids: list[str] = []
    new_token_ids: list[list[int]] = []
    new_block_ids: list[Optional[tuple[list[int], ...]]] = []
    num_computed_tokens: list[int] = []

    for cached_req in itertools.chain(running_reqs, resumed_reqs):
        rid = cached_req.request_id
        req_ids.append(rid)

        # Scheduled tokens excluding the speculative ones.
        non_spec_count = (num_scheduled_tokens[rid] -
                          len(spec_decode_tokens.get(rid, ())))

        if has_connector:
            # With a KVConnector only an empty placeholder is sent so the
            # list stays index-aligned with req_ids.
            new_token_ids.append([])
        else:
            # Without a connector the scheduled (non-speculative) token ids
            # are sent along with the request.
            begin = cached_req.num_computed_tokens
            new_token_ids.append(
                cached_req.all_token_ids[begin:begin + non_spec_count])

        new_block_ids.append(
            req_to_new_blocks[rid].get_block_ids(allow_none=True))
        num_computed_tokens.append(cached_req.num_computed_tokens)

    # One flag per request: False for running entries, True for resumed ones.
    resumed_from_preemption = ([False] * len(running_reqs) +
                               [True] * len(resumed_reqs))

    return CachedRequestData(
        req_ids=req_ids,
        resumed_from_preemption=resumed_from_preemption,
        new_token_ids=new_token_ids,
        new_block_ids=new_block_ids,
        num_computed_tokens=num_computed_tokens,
    )
|
||||
19
vllm_br/v1/engine/__init__.py
Normal file
19
vllm_br/v1/engine/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import async_llm # noqa
|
||||
from . import core # noqa: F401
|
||||
from . import llm_engine # noqa
|
||||
BIN
vllm_br/v1/engine/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/engine/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/engine/__pycache__/async_llm.cpython-310.pyc
Normal file
BIN
vllm_br/v1/engine/__pycache__/async_llm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/engine/__pycache__/core.cpython-310.pyc
Normal file
BIN
vllm_br/v1/engine/__pycache__/core.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/engine/__pycache__/llm_engine.cpython-310.pyc
Normal file
BIN
vllm_br/v1/engine/__pycache__/llm_engine.cpython-310.pyc
Normal file
Binary file not shown.
179
vllm_br/v1/engine/async_llm.py
Normal file
179
vllm_br/v1/engine/async_llm.py
Normal file
@@ -0,0 +1,179 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import socket
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.tracing import init_tracer
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
|
||||
from vllm_br import envs as envs_br
|
||||
from vllm_br.utils import (create_cpu_all_reduce_shared_mem,
|
||||
get_cpu_all_reduce_shared_mem)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@patch_to(AsyncLLM)
def __init__(
    self,
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    use_cached_outputs: bool = False,
    log_requests: bool = True,
    start_engine_loop: bool = True,
    stat_loggers: Optional[list[StatLoggerFactory]] = None,
    client_addresses: Optional[dict[str, str]] = None,
    client_count: int = 1,
    client_index: int = 0,
) -> None:
    """Initialise an ``AsyncLLM`` (replacement attached via ``patch_to``).

    In addition to building the tokenizer, processors, engine-core client,
    stat loggers and optional profiler, this constructor creates the CPU
    all-reduce shared memory when ``VLLM_BR_USE_CPU_ALL_REDUCE`` is
    non-zero.

    Args:
        vllm_config: Global configuration.
        executor_class: Executor implementation handed to the engine core.
        log_stats: Whether to log stats with the default stat loggers.
        usage_context: Usage context of the LLM (not read in this body).
        mm_registry: Multi-modal registry passed to the ``Processor``.
        use_cached_outputs: Not read in this body.
        log_requests: Whether to log requests.
        start_engine_loop: Not read in this body.
        stat_loggers: Custom stat logger factories. Passing a list turns
            stats logging on even when ``log_stats`` is false.
        client_addresses: Forwarded to the engine-core client factory.
        client_count: Forwarded to the engine-core client factory and to
            the stat logger manager.
        client_index: Forwarded to the engine-core client factory.

    Raises:
        ValueError: If ``envs.VLLM_USE_V1`` is false.
    """
    if not envs.VLLM_USE_V1:
        raise ValueError(
            "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
            "This should not happen. As a workaround, try using "
            "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
            "VLLM_USE_V1=0 or 1 and report this issue on Github.")

    # The shared memory for CPU all-reduce is created before anything else
    # is set up.
    if envs_br.VLLM_BR_USE_CPU_ALL_REDUCE != 0:
        create_cpu_all_reduce_shared_mem()

    # Ensure we can serialize custom transformer configs
    maybe_register_config_serialize_by_value()

    self.model_config = vllm_config.model_config
    self.vllm_config = vllm_config
    self.observability_config = vllm_config.observability_config
    self.log_requests = log_requests

    # Custom stat loggers switch logging on by themselves.
    has_custom_loggers = stat_loggers is not None
    self.log_stats = log_stats or has_custom_loggers
    if has_custom_loggers and not log_stats:
        logger.info(
            "AsyncLLM created with log_stats=False and non-empty custom "
            "logger list; enabling logging without default stat loggers")

    # Tokenizer, unless its initialisation was explicitly disabled.
    if not self.model_config.skip_tokenizer_init:
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config)
    else:
        self.tokenizer = None

    # Processor (converts Inputs --> EngineCoreRequests).
    self.processor = Processor(
        vllm_config=vllm_config,
        tokenizer=self.tokenizer,
        mm_registry=mm_registry,
    )

    # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
    self.output_processor = OutputProcessor(self.tokenizer,
                                            log_stats=self.log_stats)
    traces_endpoint = self.observability_config.otlp_traces_endpoint
    if traces_endpoint is not None:
        self.output_processor.tracer = init_tracer("vllm.llm_engine",
                                                   traces_endpoint)

    # EngineCore client.
    self.engine_core = EngineCoreClient.make_async_mp_client(
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=self.log_stats,
        client_addresses=client_addresses,
        client_count=client_count,
        client_index=client_index,
    )

    # Stat loggers; the default loggers are enabled only when the caller
    # asked for log_stats.
    self.logger_manager: Optional[StatLoggerManager] = None  # type: ignore
    if self.log_stats:
        stat_manager = StatLoggerManager(
            vllm_config=vllm_config,
            engine_idxs=self.engine_core.engine_ranks_managed,
            custom_stat_loggers=stat_loggers,
            enable_default_loggers=log_stats,
            client_count=client_count,
        )
        self.logger_manager = stat_manager
        stat_manager.log_engine_initialized()

    self.output_handler: Optional[asyncio.Task] = None  # type: ignore
    try:
        # Start the output handler right away when an event loop is already
        # running; get_running_loop() raises RuntimeError otherwise.
        asyncio.get_running_loop()
        self._run_output_handler()
    except RuntimeError:
        pass

    profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
    if not profiler_dir:
        self.profiler = None
    else:
        logger.info(
            "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
            profiler_dir)
        worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
        cpu_only = [
            torch.profiler.ProfilerActivity.CPU,
        ]
        self.profiler = torch.profiler.profile(
            activities=cpu_only,
            with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                profiler_dir,
                worker_name=worker_name,
                use_gzip=True))
|
||||
|
||||
|
||||
@patch_to(AsyncLLM)
def __del__(self):
    """Release the CPU all-reduce shared memory, then shut the engine down.

    The shared-memory object is looked up once and cleaned up only when one
    exists. ``self.shutdown()`` runs in a ``finally`` clause so that an
    exception raised by the cleanup cannot prevent the shutdown call.
    """
    try:
        shared_mem = get_cpu_all_reduce_shared_mem()
        if shared_mem is not None:
            shared_mem._cleanup()
    finally:
        self.shutdown()
|
||||
157
vllm_br/v1/engine/core.py
Normal file
157
vllm_br/v1/engine/core.py
Normal file
@@ -0,0 +1,157 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
|
||||
get_kv_cache_configs)
|
||||
from vllm.v1.engine import EngineCoreOutputs
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
|
||||
@patch_to(EngineCore)
def _initialize_kv_caches(
        self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
    """Determine KV-cache memory, build the cache configs and initialize them.

    Args:
        vllm_config: Engine configuration forwarded to
            ``get_kv_cache_configs``.

    Returns:
        ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
        ``num_cpu_blocks`` is always 0 in this implementation.
    """
    start = time.time()

    # Get all kv cache needed by the model
    kv_cache_specs = self.model_executor.get_kv_cache_specs()

    has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
    if has_kv_cache:
        if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
            # Elastic scale-up launch: take the memory size from the DP group
            # instead of profiling locally.
            dp_group = getattr(self, "dp_group", None)
            assert dp_group is not None
            self.available_gpu_memory_for_kv_cache = \
                ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
            # Same value is used for every entry in kv_cache_specs.
            available_gpu_memory = [self.available_gpu_memory_for_kv_cache
                                    ] * len(kv_cache_specs)
        else:
            # Profiles the peak memory usage of the model to determine how
            # much memory can be allocated for kv cache.
            available_gpu_memory = (
                self.model_executor.determine_available_memory())
            self.available_gpu_memory_for_kv_cache = \
                available_gpu_memory[0]
    else:
        # Attention free models don't need memory for kv cache
        available_gpu_memory = [0] * len(kv_cache_specs)
        # NOTE(review): this reassignment discards the zero-filled list from
        # the previous line; confirm that the override (and the block it is
        # nested in) is intentional.
        available_gpu_memory = self.model_executor.determine_available_memory()
    assert len(kv_cache_specs) == len(available_gpu_memory)

    kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
                                            available_gpu_memory)
    scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
        kv_cache_configs)
    num_gpu_blocks = scheduler_kv_cache_config.num_blocks
    num_cpu_blocks = 0

    # Initialize kv cache and warmup the execution
    self.model_executor.initialize_from_config(kv_cache_configs)

    elapsed = time.time() - start
    logger.info(("init engine (profile, create kv cache, "
                 "warmup model) took %.2f seconds"), elapsed)
    return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
|
||||
|
||||
|
||||
@patch_to(EngineCore)
def step_with_batch_queue(
        self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
    """Schedule and execute batches with the batch queue.

    Note that if nothing to output in this step, None is returned.

    The execution flow is as follows:
    1. If the scheduler has requests, schedule a new batch, submit it with
       ``non_block=True`` and push it onto the batch queue. When that batch
       contains tokens, the queue is still not full and the oldest queued
       future is not done yet, return ``(None, True)`` immediately; filling
       the batch queue has a higher priority than getting model outputs.
    2. Otherwise (queue full, oldest result ready, or nothing with tokens was
       scheduled) block until the oldest batch in the queue is finished.
    3. Update the scheduler from the output.

    Returns:
        ``(engine_core_outputs, model_executed)``. ``engine_core_outputs`` is
        ``None`` when no scheduler update happened in this step.
    """
    batch_queue = self.batch_queue
    assert batch_queue is not None

    # Try to schedule a new batch if the batch queue is not full, but
    # the scheduler may return an empty batch if all requests are scheduled.
    # Note that this is not blocking.
    assert len(batch_queue) < self.batch_queue_size

    model_executed = False
    if self.scheduler.has_requests():
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output,
                                                   non_block=True)
        # New batches enter on the left; the oldest batch is popped from the
        # right further below.
        batch_queue.appendleft(
            (future, scheduler_output))  # type: ignore[arg-type]

        model_executed = scheduler_output.total_num_scheduled_tokens > 0
        if model_executed and len(batch_queue) < self.batch_queue_size \
            and not batch_queue[-1][0].done():
            # Don't block on next worker response unless the queue is full
            # or there are no more requests to schedule.
            return None, True

    elif not batch_queue:
        # Queue is empty. We should not reach here since this method should
        # only be called when the scheduler contains requests or the queue
        # is non-empty.
        return None, False

    # Block until the next result is available.
    future, scheduler_output = batch_queue.pop()
    model_output = self.execute_model_with_error_logging(
        lambda _: future.result(), scheduler_output)
    if scheduler_output.total_num_scheduled_tokens != 0:
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)
        if self.use_spec_decode:
            # Take the draft token ids.
            # draft_token_ids = self.model_executor.take_draft_token_ids()
            # Here the draft token ids are read from the model output itself
            # rather than fetched from the executor.
            if model_output.draft_token_ids is not None:
                model_output.draft_token_ids.req_ids = model_output.req_ids
                self.scheduler.update_draft_token_ids(
                    model_output.draft_token_ids)
            else:
                pass
        return engine_core_outputs, model_executed
    else:
        # The popped batch had no scheduled tokens: nothing to report.
        return None, False
|
||||
|
||||
|
||||
@patch_to(EngineCoreProc)
def _process_engine_step(self) -> bool:
    """Run one engine step and enqueue its outputs.

    Called only when there are unfinished local requests.

    Returns:
        The ``model_executed`` flag reported by ``self.step_fn()``.
    """
    step_outputs, model_executed = self.step_fn()

    # Forward every (key, EngineCoreOutputs) pair to the output queue; a
    # falsy result (None or empty) produces nothing.
    if step_outputs:
        for entry in step_outputs.items():
            self.output_queue.put_nowait(entry)

    # NOTE: the post-step hook is left disabled in this override:
    # if outputs is not None:
    #     self.post_step(model_executed)

    return model_executed
|
||||
143
vllm_br/v1/engine/llm_engine.py
Normal file
143
vllm_br/v1/engine/llm_engine.py
Normal file
@@ -0,0 +1,143 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.tracing import init_tracer
|
||||
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
|
||||
from vllm_br import envs as envs_br
|
||||
from vllm_br.utils import (create_cpu_all_reduce_shared_mem,
|
||||
get_cpu_all_reduce_shared_mem)
|
||||
|
||||
|
||||
@patch_to(LLMEngine)
def __init__(
    self,
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    stat_loggers: Optional[list[StatLoggerFactory]] = None,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    use_cached_outputs: bool = False,
    multiprocess_mode: bool = False,
) -> None:
    """Replacement ``LLMEngine`` constructor installed via ``patch_to``.

    In addition to wiring up the tokenizer, processors, engine core client and
    stat loggers, it creates the CPU all-reduce shared memory when
    ``envs_br.VLLM_BR_USE_CPU_ALL_REDUCE`` is non-zero.

    Args:
        vllm_config: Full engine configuration.
        executor_class: Executor type handed to the engine core client.
        log_stats: Whether statistics logging is enabled.
        usage_context: Not read in this constructor body.
        stat_loggers: Must be ``None``; any other value raises.
        mm_registry: Multimodal registry passed to the ``Processor``.
        use_cached_outputs: Not read in this constructor body.
        multiprocess_mode: Passed to ``EngineCoreClient.make_client``; also
            controls DP-group initialization and the ``model_executor`` alias.

    Raises:
        ValueError: If ``envs.VLLM_USE_V1`` is falsy.
        NotImplementedError: If ``stat_loggers`` is not ``None``.
    """
    if not envs.VLLM_USE_V1:
        raise ValueError("Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                         "This should not happen. As a workaround, try using "
                         "LLMEngine.from_vllm_config(...) or explicitly set "
                         "VLLM_USE_V1=0 or 1 and report this issue on Github.")

    if stat_loggers is not None:
        raise NotImplementedError(
            "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
            "Set VLLM_USE_V1=0 and file and issue on Github.")
    # Create the shared memory used for CPU all-reduce when it is enabled.
    if envs_br.VLLM_BR_USE_CPU_ALL_REDUCE != 0:
        create_cpu_all_reduce_shared_mem()

    self.vllm_config = vllm_config
    self.observability_config = vllm_config.observability_config
    self.model_config = vllm_config.model_config
    self.cache_config = vllm_config.cache_config

    self.log_stats = log_stats

    executor_backend = (
        self.vllm_config.parallel_config.distributed_executor_backend)
    parallel_config = vllm_config.parallel_config
    # True when data parallelism is driven by the "external_launcher" backend.
    self.external_launcher_dp = (parallel_config.data_parallel_size > 1
                                 and executor_backend == "external_launcher")
    # important: init dp group before init the engine_core
    # In the decoupled engine case this is handled in EngineCoreProc.
    if not multiprocess_mode and parallel_config.data_parallel_size > 1 \
        and not self.external_launcher_dp:
        self.dp_group = parallel_config.stateless_init_dp_group()
    else:
        self.dp_group = None
    self.should_execute_dummy_batch = False

    if self.model_config.skip_tokenizer_init:
        self.tokenizer = None
    else:
        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config)

    # Processor (convert Inputs --> EngineCoreRequests)
    self.processor = Processor(vllm_config=vllm_config,
                               tokenizer=self.tokenizer,
                               mm_registry=mm_registry)

    # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
    self.output_processor = OutputProcessor(self.tokenizer,
                                            log_stats=self.log_stats)
    # Attach a tracer only when an OTLP traces endpoint is configured.
    if self.observability_config.otlp_traces_endpoint is not None:
        tracer = init_tracer("vllm.llm_engine",
                             self.observability_config.otlp_traces_endpoint)
        self.output_processor.tracer = tracer

    # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
    self.engine_core = EngineCoreClient.make_client(
        multiprocess_mode=multiprocess_mode,
        asyncio_mode=False,
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=self.log_stats,
    )

    self.logger_manager: Optional[StatLoggerManager] = None  # type: ignore
    if self.log_stats:
        # stat_loggers is always None here (a non-None value raised above).
        self.logger_manager = StatLoggerManager(
            vllm_config=vllm_config,
            custom_stat_loggers=stat_loggers,
            enable_default_loggers=log_stats,
        )
        self.logger_manager.log_engine_initialized()

    if not multiprocess_mode:
        # for v0 compatibility
        self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

    if self.external_launcher_dp:
        # If we use DP in external launcher mode, we reuse the
        # existing DP group used for data communication.
        self.dp_group = get_dp_group().cpu_group

    # Don't keep the dummy data in memory
    self.reset_mm_cache()
|
||||
|
||||
|
||||
@patch_to(LLMEngine)
def __del__(self):
    """Tear down the stateless DP process group and the CPU all-reduce memory.

    The DP group is destroyed only when this engine holds one and it is not
    the group shared with the external launcher. The group is bound to a
    local first: writing ``if dp_group := getattr(...) and not ...`` binds
    the *boolean* result of the ``and`` expression (``:=`` has the lowest
    precedence), which would pass ``True`` to the destroy call instead of
    the process group.
    """
    dp_group = getattr(self, "dp_group", None)
    # getattr keeps the finalizer safe on a partially constructed instance.
    if dp_group and not getattr(self, "external_launcher_dp", False):
        stateless_destroy_torch_distributed_process_group(dp_group)

    shared_mem = get_cpu_all_reduce_shared_mem()
    if shared_mem is not None:
        shared_mem._cleanup()
|
||||
20
vllm_br/v1/executor/__init__.py
Normal file
20
vllm_br/v1/executor/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from vllm_br.executor.ray_distributed_executor import ( # noqa: F401
|
||||
_init_workers_ray_br)
|
||||
from . import ray_distributed_executor
|
||||
|
||||
__all__ = ["_init_workers_ray_br", "ray_distributed_executor"]
|
||||
BIN
vllm_br/v1/executor/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/executor/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
75
vllm_br/v1/executor/ray_distributed_executor.py
Normal file
75
vllm_br/v1/executor/ray_distributed_executor.py
Normal file
@@ -0,0 +1,75 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from concurrent.futures import Future
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.executor.ray_distributed_executor import RayDistributedExecutor
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
|
||||
class FutureWrapper(Future):
    """Future-shaped adapter around a Ray output reference.

    It lets ``execute_model()`` return something with a ``result()`` method
    while the actual value is still held by Ray.
    """

    def __init__(self, ref):
        """Remember the reference that ``result()`` will resolve."""
        super().__init__()
        self.ref = ref

    def result(self, timeout=None):
        """Resolve the stored reference with ``ray.get``.

        Raises:
            NotImplementedError: If any ``timeout`` other than ``None`` is
                given.
        """
        if timeout is None:
            return ray.get(self.ref)
        raise NotImplementedError("timeout is not supported")
|
||||
|
||||
|
||||
def execute_model(
    self,
    scheduler_output,
    non_block: bool = False,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
    """Dispatch ``scheduler_output`` to every Ray worker and return the result.

    Every worker in every pipeline stage receives an ``execute_model_ray``
    remote call; only the references of the last stage are kept, and the
    first of those is what gets returned (resolved or wrapped).

    Args:
        scheduler_output: Forwarded unchanged to each worker.
        non_block: Currently ignored (see TODO below).

    Returns:
        The resolved output when ``self.max_concurrent_batches == 1``,
        otherwise a ``FutureWrapper`` around the first last-stage reference.
    """
    # TODO: current only support non_block is True, need to adapt new non_block param
    assert self.parallel_config.use_ray

    final_stage = len(self.pp_tp_workers) - 1
    refs = []
    for stage, stage_workers in enumerate(self.pp_tp_workers):
        # The remote call is issued for every stage, even when its
        # references are not kept.
        stage_refs = [
            member.execute_model_ray.remote(scheduler_output)
            for member in stage_workers
        ]
        if stage == final_stage:
            refs.extend(stage_refs)

    # When PP is not used, we block here until the result is available.
    if self.max_concurrent_batches == 1:
        return ray.get(refs[0])

    # When PP is used, we return a FutureWrapper immediately so that
    # the scheduler can yield to the next batch.
    return FutureWrapper(refs[0])
|
||||
|
||||
|
||||
def execute_model_ray(
        self,
        scheduler_output: SchedulerOutput) -> Optional[ModelRunnerOutput]:
    """Run ``scheduler_output`` on the wrapped worker and return its output."""
    wrapped_worker = self.worker
    return wrapped_worker.execute_model(scheduler_output)
|
||||
|
||||
|
||||
# Install the functions defined above on the upstream classes at import time.
RayDistributedExecutor.execute_model = execute_model  # type: ignore[attr-defined]
# `execute_model` calls `worker.execute_model_ray.remote(...)`, so the worker
# wrapper needs this attribute as well.
RayWorkerWrapper.execute_model_ray = execute_model_ray  # type: ignore[attr-defined]
|
||||
25
vllm_br/v1/kv_cache_interface.py
Normal file
25
vllm_br/v1/kv_cache_interface.py
Normal file
@@ -0,0 +1,25 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# TODO(ychun) temp annotation
|
||||
# @property # type: ignore
|
||||
# def AttentionSpec_page_size_bytes(self) -> int:
|
||||
# # For MLA we only store a single latent vector, BR166 uses BB, so it needs to be multiplied by 2
|
||||
# coef = 1 if (self.use_mla and envs.VLLM_BR_DEVICE_SPC_NUM <= 16) else 2
|
||||
# return coef * self.block_size * self.num_kv_heads * self.head_size \
|
||||
# * get_dtype_size(self.dtype)
|
||||
|
||||
# AttentionSpec.page_size_bytes = AttentionSpec_page_size_bytes
|
||||
41
vllm_br/v1/outputs.py
Normal file
41
vllm_br/v1/outputs.py
Normal file
@@ -0,0 +1,41 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.v1.outputs import (DraftTokenIds, KVConnectorOutput, LogprobsLists,
|
||||
LogprobsTensors, ModelRunnerOutput)
|
||||
|
||||
|
||||
@patch_to(ModelRunnerOutput)
def __init__(self,
             req_ids: list[str],
             req_id_to_index: dict[str, int],
             sampled_token_ids: list[list[int]],
             logprobs: Optional[LogprobsLists],
             prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]],
             pooler_output: list[Optional[torch.Tensor]],
             kv_connector_output: Optional[KVConnectorOutput] = None,
             num_nans_in_logits: Optional[dict[str, int]] = None,
             draft_token_ids: Optional["DraftTokenIds"] = None):
    """Build a ``ModelRunnerOutput`` that additionally carries draft tokens.

    All arguments except ``draft_token_ids`` are passed positionally, in the
    order listed, to ``self._orig___init__``; ``draft_token_ids`` is then
    stored as an extra attribute on the instance.
    """
    # NOTE(review): `_orig___init__` is presumably the pre-patch constructor
    # preserved by `patch_to` -- confirm against fastcore.
    self._orig___init__(req_ids, req_id_to_index, sampled_token_ids, logprobs,
                        prompt_logprobs_dict, pooler_output,
                        kv_connector_output, num_nans_in_logits)

    self.draft_token_ids = draft_token_ids
|
||||
17
vllm_br/v1/sample/__init__.py
Normal file
17
vllm_br/v1/sample/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import ops # noqa: F401
|
||||
BIN
vllm_br/v1/sample/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/sample/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
17
vllm_br/v1/sample/ops/__init__.py
Normal file
17
vllm_br/v1/sample/ops/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import logprobs, topk_topp_sampler # noqa: F401
|
||||
BIN
vllm_br/v1/sample/ops/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/sample/ops/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/sample/ops/__pycache__/logprobs.cpython-310.pyc
Normal file
BIN
vllm_br/v1/sample/ops/__pycache__/logprobs.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
40
vllm_br/v1/sample/ops/logprobs.py
Normal file
40
vllm_br/v1/sample/ops/logprobs.py
Normal file
@@ -0,0 +1,40 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
"""Some utilities for logprobs, including logits."""
|
||||
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.v1.sample.ops import logprobs
|
||||
|
||||
|
||||
@patch_to(logprobs)
def batched_count_greater_than(x: torch.Tensor,
                               values: torch.Tensor) -> torch.Tensor:
    """
    Counts, for each row of x, the elements that are greater than or equal
    to the corresponding value in values (the comparison is ``>=``).

    This replacement evaluates the comparison directly; no torch.compile
    wrapper is applied in this file.

    Args:
        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).

    Returns:
        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
    """
    at_or_above = x >= values
    return at_or_above.sum(-1)
|
||||
138
vllm_br/v1/sample/ops/topk_topp_sampler.py
Normal file
138
vllm_br/v1/sample/ops/topk_topp_sampler.py
Normal file
@@ -0,0 +1,138 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.v1.sample.ops import topk_topp_sampler
|
||||
|
||||
|
||||
def topk_topp_sampler_forward_native(
    self,
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    PyTorch-native implementation of top-k and top-p sampling.

    The logits tensor may be updated in-place.

    Args:
        logits: Logits to filter and sample from.
        generators: Per-row generators forwarded to ``random_sample``.
        k: Optional per-row top-k values.
        p: Optional per-row top-p values.

    Returns:
        The sampled index for each row, as returned by ``random_sample``.
    """
    # Resolve the masking function through the patched module rather than
    # through this file's global of the same name: `apply_top_k_top_p` is
    # decorated with `patch_to` below, and the decorator's return value (not
    # necessarily the function) is what ends up bound to that global.
    logits = topk_topp_sampler.apply_top_k_top_p(logits, k, p)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs, generators)
|
||||
|
||||
|
||||
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Mask every logit that falls below its row's k-th largest value.

    Only the ``max(k)`` largest entries per row are extracted, so the whole
    vocabulary is never sorted. Rows whose ``k`` equals the vocabulary size
    are left untouched.

    The logits tensor is updated in-place and returned; ``k`` is not modified.
    """
    neg_inf = -float("inf")
    vocab_size = logits.shape[1]

    # Rows that keep the full vocabulary get a placeholder k of 1 so the
    # gather below stays in range; their threshold is reset afterwards.
    keeps_everything = k == vocab_size
    effective_k = k.masked_fill(keeps_everything, 1)
    widest_k = effective_k.max()

    # Column of the k-th largest value inside the [batch, widest_k] result.
    kth_column = (effective_k - 1).unsqueeze(1).long()
    largest_values = logits.topk(widest_k, dim=1).values
    threshold = largest_values.gather(1, kth_column)

    # A -inf threshold masks nothing for the rows that keep everything.
    threshold.masked_fill_(keeps_everything.unsqueeze(1), neg_inf)
    logits.masked_fill_(logits < threshold, neg_inf)
    return logits
|
||||
|
||||
|
||||
# scatter usage not support on br, need fix.
|
||||
@patch_to(topk_topp_sampler)
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    Without ``p`` the logits are returned as-is (no ``k``) or handed to
    ``apply_top_k_only``. With ``p`` the logits are sorted in ascending
    order, masked, and scattered back to their original positions; the sort
    can be slow for large batches.

    The logits tensor may be updated in-place.
    """
    if p is None:
        # Avoid sorting vocab for top-k only case.
        return logits if k is None else apply_top_k_only(logits, k)

    sorted_logits, sort_order = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Apply top-k: in ascending order the k-th largest value sits at
        # column (vocab_size - k).
        kth_position = sorted_logits.size(1) - k.to(torch.long)  # shape: B
        kth_value = sorted_logits.gather(1, kth_position.unsqueeze(dim=1))
        below_kth = sorted_logits < kth_value
        sorted_logits.masked_fill_(below_kth, -float("inf"))

    # Apply top-p (p is known to be set at this point).
    sorted_probs = sorted_logits.softmax(dim=-1)
    # The cumulative sum is written into the probability buffer itself.
    cumulative = torch.cumsum(sorted_probs, dim=-1, out=sorted_probs)
    outside_nucleus = cumulative <= 1 - p.unsqueeze(dim=1)
    # at least one
    outside_nucleus[:, -1] = False
    sorted_logits.masked_fill_(outside_nucleus, -float("inf"))

    # Re-sort the probabilities: clone first, then scatter in-place into
    # the clone.
    restored = sorted_logits.clone()
    restored = restored.scatter_(dim=-1, index=sort_order, src=sorted_logits)
    return restored
|
||||
|
||||
|
||||
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample one index per row from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.

    ``probs`` is divided in-place by exponential noise and the arg-max of
    each row is returned as a flat tensor.
    """
    noise = torch.empty_like(probs)

    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    every_row_seeded = len(generators) == probs.shape[0]
    if not every_row_seeded:
        noise.exponential_()

    # TODO(woosuk): This can be slow because we handle each request
    # one by one. Optimize this.
    for row, row_generator in generators.items():
        noise[row].exponential_(generator=row_generator)

    scaled = probs.div_(noise)
    return scaled.argmax(dim=-1).view(-1)
|
||||
|
||||
|
||||
# vllm.v1.sample.ops.topk_topp_sampler.TopKTopPSampler.forward_native = topk_topp_sampler_forward_native
|
||||
17
vllm_br/v1/spec_decode/__init__.py
Normal file
17
vllm_br/v1/spec_decode/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import eagle # noqa: F401
|
||||
BIN
vllm_br/v1/spec_decode/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/spec_decode/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/spec_decode/__pycache__/eagle.cpython-310.pyc
Normal file
BIN
vllm_br/v1/spec_decode/__pycache__/eagle.cpython-310.pyc
Normal file
Binary file not shown.
265
vllm_br/v1/spec_decode/eagle.py
Normal file
265
vllm_br/v1/spec_decode/eagle.py
Normal file
@@ -0,0 +1,265 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
import vllm_br.envs as biren_envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm_br.v1.worker.model_runner import SUPACommonAttentionMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
|
||||
def wrapper_EagleProposer_init(fn):
    """Wrap ``EagleProposer.__init__`` to adjust the draft model config.

    FIXME: temporary fix for enabling MLA in EagleProposer.

    After the wrapped initializer has run, the wrapper sets ``weight_type``,
    ``use_ds_mla`` and ``use_ds_mla_sparse`` on ``self.draft_model_config``.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        draft_config = self.draft_model_config
        draft_config.weight_type = biren_envs.VLLM_BR_WEIGHT_TYPE
        draft_config.use_ds_mla = True
        # The sparse variant is switched on when the HF config exposes an
        # `index_topk` attribute.
        draft_config.use_ds_mla_sparse = hasattr(draft_config.hf_config,
                                                 "index_topk")

    return wrapper


# Install the wrapped initializer on the upstream class.
EagleProposer.__init__ = wrapper_EagleProposer_init(
    EagleProposer.__init__)  # noqa: E501
|
||||
|
||||
|
||||
@patch_to(EagleProposer)
def prepare_inputs(
    self,
    common_attn_metadata: SUPACommonAttentionMetadata,
    sampled_token_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> tuple[SUPACommonAttentionMetadata, torch.Tensor]:
    """
    Prepare the inputs for speculative decoding.

    Builds a new SUPACommonAttentionMetadata that accounts for the rejected
    tokens (and newly sampled tokens), and returns the indices of the tokens
    that should be fed to the speculator.

    Args:
        common_attn_metadata: Attention metadata of the step whose sampled
            tokens are being processed.
        sampled_token_ids: Per-request lists of sampled token ids; only the
            length of each list is used here.
        num_draft_tokens: Per-request number of draft tokens.

    Returns:
        A tuple ``(spec_common_attn_metadata, token_indices)``.
    """
    # E.g.
    #  common_attn_metadata.query_start_loc{_cpu}:
    #       [0, q1, q1 + q2, q1 + q2 + q3]
    #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
    #  num_rejected_tokens: [n1, n2, n3]
    # This function computes the intermediate values:
    #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
    # And returns:
    #  common_attn_metadata.query_start_loc{_cpu}:
    #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    #  common_attn_metadata.seq_lens{_cpu}:
    #       [s1 - n1, s2 - n2, s3 - n3]
    #       (the code below subtracts num_rejected_tokens only; no +1 is
    #       applied)
    #  token_indices: [0, 1, ..., q1 - n1 - 1,
    #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
    #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

    # A request with n > 0 draft tokens rejected `n + 1 - len(sampled)` of
    # them; a request without draft tokens rejected none.
    num_rejected_tokens = [
        n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
        for i, n in enumerate(num_draft_tokens)
    ]
    num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)

    device = common_attn_metadata.query_start_loc.device
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
        - num_rejected_tokens

    # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
    # NOTE: despite the `new_` prefix this is the query length *before*
    # rejected tokens are removed; it is also what max_query_len is derived
    # from below.
    new_query_len_per_req = (query_start_loc_cpu[1:] -
                             query_start_loc_cpu[:-1])
    # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
    new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
    new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

    # [q1 - n1, q2 - n2, q3 - n3] ->
    # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    new_query_start_loc_cpu = torch.zeros(query_start_loc_cpu.shape,
                                          dtype=torch.int32,
                                          pin_memory=is_pin_memory_available())
    # The numpy array is used as the `out=` target of the cumsum below, so
    # the cumulative sums land in new_query_start_loc_cpu (element 0 stays 0).
    new_query_start_loc_np = new_query_start_loc_cpu.numpy()
    np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

    total_num_tokens = new_query_start_loc_np[-1]
    # Example assuming num_tokens_per_req_np = [2, 4, 3]
    # this implies that `new_query_start_locs` is:
    # [0, 2, 6, 9] ->
    # [0, 0, 2, 2, 2, 2, 6, 6, 6]
    #  _r1_  ____r2____  ___r3__
    new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                              new_num_tokens_per_req_np)
    # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
    # [0, 1, 0, 1, 2, 3, 0, 1, 2]
    #  _r1_  ____r2____  ___r3__
    # assumes self.token_arange_np[i] == i — TODO confirm against the class
    # that defines it.
    token_offests = self.token_arange_np[:total_num_tokens] \
        - new_query_start_locs_expanded

    # Expand starting positions to match token pattern
    # [0, q1, q1 + q2] ->
    # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
    #  _r1_  _____r2_______  ___________r3____________
    old_query_start_locs_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(),
                                              new_num_tokens_per_req_np)
    # Final token indices are:
    # [0, 1,                                // req 1
    #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
    #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
    token_indices_np = token_offests + old_query_start_locs_expanded
    token_indices = torch.from_numpy(token_indices_np).to(device,
                                                          non_blocking=True)

    # seq_start_loc = torch.from_numpy(
    #     np.insert(np.add.accumulate(common_attn_metadata.seq_lens.cpu().numpy()), 0,
    #               0)).to(common_attn_metadata.query_start_loc, non_blocking=True)
    spec_common_attn_metadata = SUPACommonAttentionMetadata(
        query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
        seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
        query_start_loc_cpu=new_query_start_loc_cpu,
        seq_lens_cpu=new_seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        max_query_len=new_query_len_per_req.max().item(),
        max_seq_len=new_seq_lens_cpu.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        # Only the slots of the tokens that are kept are forwarded.
        slot_mapping=common_attn_metadata.slot_mapping[token_indices],
        causal=True,
        # seq_start_loc=seq_start_loc
    )
    return spec_common_attn_metadata, token_indices
|
||||
|
||||
|
||||
@patch_to(EagleProposer)
def prepare_inputs_padded(self,
                            common_attn_metadata: SUPACommonAttentionMetadata,
                            spec_decode_metadata: SpecDecodeMetadata,
                            valid_sampled_tokens_count: torch.Tensor) -> \
                tuple[SUPACommonAttentionMetadata, torch.Tensor, torch.Tensor]:
    """
    Prepare the inputs for speculative decoding without dropping rejected
    tokens.

    It updates the common_attn_metadata for speculative decoding,
    but does not consider the rejected tokens. Instead, all tokens
    are included as inputs to the speculator, with the rejected tokens
    used as padding and filtered out later by `token_indices_to_sample`.
    No blocking CPU operations should be introduced in this function.

    Args:
        common_attn_metadata: Attention metadata of the current step.
        spec_decode_metadata: Provides `cu_num_draft_tokens`, the cumulative
            number of draft tokens per request.
        valid_sampled_tokens_count: Per-request count of valid sampled
            tokens.

    Returns:
        A tuple ``(spec_common_attn_metadata, token_indices,
        token_indices_to_sample)``.
    """
    # Turn the cumulative counts back into per-request counts: the first
    # entry is kept as is, the rest are adjacent differences.
    num_draft_tokens_gpu = torch.cat([
        spec_decode_metadata.cu_num_draft_tokens[0:1],
        spec_decode_metadata.cu_num_draft_tokens[1:] -
        spec_decode_metadata.cu_num_draft_tokens[:-1]
    ])

    # Requests without draft tokens have nothing rejected; otherwise
    # rejected = drafts + 1 - valid sampled tokens.
    num_rejected_tokens_gpu = torch.where(
        num_draft_tokens_gpu > 0,
        num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
        torch.zeros_like(num_draft_tokens_gpu))

    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

    # Per-request query lengths; unchanged from the input because no token
    # is removed in the padded path.
    new_query_len_per_req = (query_start_loc_cpu[1:] -
                             query_start_loc_cpu[:-1])

    total_num_tokens = query_start_loc_cpu[-1].item()
    # presumably self.arange is a pre-allocated 0..N index tensor — verify
    # against the class that defines it.
    token_indices = self.arange[:total_num_tokens]

    # Cumulative sequence lengths with a leading 0, converted to the dtype
    # and device of query_start_loc.
    # NOTE(review): `seq_lens.cpu()` looks like a blocking device-to-host
    # copy if seq_lens lives on the device, which conflicts with the
    # docstring above; `common_attn_metadata.seq_lens_cpu` is also available
    # — confirm whether it can be used instead.
    seq_start_loc = torch.from_numpy(
        np.insert(
            np.add.accumulate(common_attn_metadata.seq_lens.cpu().numpy()), 0,
            0)).to(common_attn_metadata.query_start_loc, non_blocking=True)
    spec_common_attn_metadata = SUPACommonAttentionMetadata(
        query_start_loc=common_attn_metadata.query_start_loc,
        seq_lens=common_attn_metadata.seq_lens,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        max_query_len=new_query_len_per_req.max().item(),
        max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        # The index tensor is cast to int64 before it is used for indexing.
        slot_mapping=common_attn_metadata.slot_mapping[token_indices.long()],
        causal=True,
        # context_lens=context_lens,
        # max_decode_seq_len=self.seq_lens.np[:num_reqs].max(),
        seq_start_loc=seq_start_loc)

    # Index of the last token of each request, moved back by the number of
    # rejected tokens of that request.
    token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
        - num_rejected_tokens_gpu

    return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
||||
|
||||
|
||||
def wrapper_EagleProposer_propose(fn):
    """Wrap ``EagleProposer.propose`` to normalize ``last_token_indices``.

    The wrapper fills in ``last_token_indices`` from
    ``common_attn_metadata.query_start_loc`` when the caller passes ``None``
    and always casts it with ``.long()`` before delegating to ``fn``. All
    arguments are forwarded positionally, in the original order.
    """

    @wraps(fn)
    def wrapper(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ):
        # Default: the last position of every request's query span.
        resolved_indices = (common_attn_metadata.query_start_loc[1:] - 1
                            if last_token_indices is None else
                            last_token_indices)
        forwarded = (
            target_token_ids,
            target_positions,
            target_hidden_states,
            next_token_ids,
            resolved_indices.long(),
            common_attn_metadata,
            sampling_metadata,
            mm_embeds,
        )
        return fn(self, *forwarded)

    return wrapper


# Install the wrapped method on the upstream class.
EagleProposer.propose = wrapper_EagleProposer_propose(
    EagleProposer.propose)  # noqa: E501
|
||||
15
vllm_br/v1/worker/__init__.py
Normal file
15
vllm_br/v1/worker/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
BIN
vllm_br/v1/worker/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm_br/v1/worker/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm_br/v1/worker/__pycache__/model_runner.cpython-310.pyc
Normal file
BIN
vllm_br/v1/worker/__pycache__/model_runner.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/v1/worker/__pycache__/ubatching.cpython-310.pyc
Normal file
BIN
vllm_br/v1/worker/__pycache__/ubatching.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/worker/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm_br/v1/worker/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/v1/worker/__pycache__/worker.cpython-310.pyc
Normal file
BIN
vllm_br/v1/worker/__pycache__/worker.cpython-310.pyc
Normal file
Binary file not shown.
49
vllm_br/v1/worker/kv_connector_model_runner_mixin.py
Normal file
49
vllm_br/v1/worker/kv_connector_model_runner_mixin.py
Normal file
@@ -0,0 +1,49 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorModelRunnerMixin)
|
||||
from vllm_br.forward_context import set_forward_context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
# @staticmethod
|
||||
def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
                            vllm_config: VllmConfig) -> ModelRunnerOutput:
    """Run the KV connector send/recv step when there is no model forward.

    The KV connector output context is entered (and immediately left) inside
    a forward context whose first argument is ``None``, so KV send/recv
    happens even if there is no work to do.

    Returns:
        ``EMPTY_MODEL_RUNNER_OUTPUT`` itself when nothing finished sending or
        receiving; otherwise a shallow copy of it carrying the connector
        output in ``kv_connector_output``.
    """
    with set_forward_context(None, vllm_config):
        with KVConnectorModelRunnerMixin._get_kv_connector_output(
                scheduler_output, wait_for_save=False) as kv_connector_output:
            pass

    if (kv_connector_output.finished_sending
            or kv_connector_output.finished_recving):
        # Copy so the shared empty sentinel is never mutated.
        runner_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        runner_output.kv_connector_output = kv_connector_output
        return runner_output

    return EMPTY_MODEL_RUNNER_OUTPUT


# Install the function on the upstream mixin.
KVConnectorModelRunnerMixin.kv_connector_no_forward = kv_connector_no_forward
|
||||
4595
vllm_br/v1/worker/model_runner.py
Normal file
4595
vllm_br/v1/worker/model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
413
vllm_br/v1/worker/supa_ubatch_wrapper.py
Normal file
413
vllm_br/v1/worker/supa_ubatch_wrapper.py
Normal file
@@ -0,0 +1,413 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id)
|
||||
from vllm.forward_context import (create_forward_context, get_forward_context,
|
||||
override_forward_context)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||
from vllm_br.compilation.supa_graph import SUPAGraphWrapper
|
||||
from vllm_br.config.compilation import SUPAGraphMode
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class UbatchMetadata:
    """Inputs and execution context for a single micro-batch (ubatch)."""
    # Per-ubatch context. The worker threads in this file enter it with
    # `with` and read its `id`, `compute_stream`, `comm_stream` and
    # `cpu_wait_event` attributes.
    context: UBatchContext
    # Slice of the model's `input_ids` belonging to this ubatch.
    input_ids: torch.Tensor
    # Slice of the model's `positions` belonging to this ubatch.
    positions: torch.Tensor
    # Slice of `inputs_embeds`, or None when no embeddings were given.
    inputs_embeds: Optional[torch.Tensor]
    # Slice of `intermediate_tensors`, or None when none were given.
    intermediate_tensors: Optional[IntermediateTensors]
    # Number of tokens in this ubatch (token_slice.stop - token_slice.start).
    num_tokens: int
|
||||
|
||||
|
||||
@dataclass
class SUPAGraphMetaData:
    """A captured SUPA graph together with the data it was captured with."""
    # The graph object used for capture.
    supagraph: torch.supa.SUPAGraph
    # NOTE(review): annotated as a single UbatchMetadata, but
    # UBatchWrapper._capture_ubatches stores the whole list of per-ubatch
    # metadata here — confirm the intended type.
    ubatch_metadata: UbatchMetadata
    # Concatenated model outputs produced during capture; None until set.
    outputs: Optional[Any] = None
|
||||
|
||||
|
||||
class SMControlContextManager:
    """Context manager that splits SMs between communication and compute."""

    def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None],
                 set_compute_sms: Callable[[int], None]):
        """
        Set up SM (Streaming Multiprocessor) allocation control.

        On entering the context, ``comm_sms`` SMs are assigned to
        communication and ``total_sms - comm_sms`` to computation. On
        exiting, both setters are called with ``total_sms`` so that all
        available SMs can be used again.

        Args:
            comm_sms (int): The number of SMs to allocate for communication.
                (The remainder will be used for computation.)
            set_comm_sms (Callable[[int], None]):
                A function that sets the number of SMs for communication.
            set_compute_sms (Callable[[int], None]):
                A function that sets the number of SMs for computation.
        """
        assert current_platform.is_supa(), \
            "SM control is currently only supported on SUPA"

        device_index = torch.supa.current_device()
        device_props = torch.supa.get_device_properties(device_index)
        total_sms = device_props.multi_processor_count

        # At least one SM must remain for computation.
        assert comm_sms < total_sms

        self.set_comm_sms = set_comm_sms
        self.set_compute_sms = set_compute_sms
        self.comm_sms = comm_sms
        self.compute_sms = total_sms - comm_sms
        self.total_sms = total_sms

    def __enter__(self):
        # Apply the split: communication first, then computation.
        self.set_comm_sms(self.comm_sms)
        self.set_compute_sms(self.compute_sms)

    def __exit__(self, exc_type, exc_value, traceback):
        # Hand every SM back to both sides.
        all_sms = self.total_sms
        self.set_comm_sms(all_sms)
        self.set_compute_sms(all_sms)
|
||||
|
||||
|
||||
class UBatchWrapper:
|
||||
|
||||
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
             runtime_mode: SUPAGraphMode, device: torch.supa.device):
    """Set up streams, synchronization and graph state for ubatching.

    Args:
        runnable: The callable that is wrapped; unknown attributes of the
            wrapper are looked up on it.
        vllm_config: Engine configuration; its ``compilation_config`` is
            kept as a separate attribute.
        runtime_mode: When this is not ``SUPAGraphMode.NONE`` a
            SUPAGraphWrapper around ``runnable`` is created and the global
            graph pool is fetched.
        device: Device used for the communication stream.
    """
    # `runnable` is assigned first because __getattr__ depends on it.
    self.runnable = runnable
    self.device = device
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config

    self.comm_stream = torch.supa.Stream(device=device)
    # Two ubatch threads plus the main thread
    self.ready_barrier = threading.Barrier(3)

    # Captured graphs, keyed by total number of tokens.
    self.supagraphs: dict[int, SUPAGraphMetaData] = {}

    graphs_enabled = runtime_mode is not SUPAGraphMode.NONE
    self.supagraph_wrapper = (SUPAGraphWrapper(
        runnable, vllm_config, runtime_mode=runtime_mode)
                              if graphs_enabled else None)
    self.graph_pool = (current_platform.get_global_graph_pool()
                       if graphs_enabled else None)

    self.sm_control = self._create_sm_control_context(vllm_config)
|
||||
|
||||
@staticmethod
def _create_sm_control_context(vllm_config: VllmConfig):
    """Build the SMControlContextManager used while running ubatches.

    The communication SM budget starts from ``envs.VLLM_DBO_COMM_SMS``. With
    expert parallelism enabled it is capped by the all2all manager's
    ``max_sms_used()`` (when that is not None). Setter callbacks stay no-ops
    unless the resulting budget is positive.
    """

    def _ignore(sms):
        return None

    comm_sms = envs.VLLM_DBO_COMM_SMS

    set_comm_sms = _ignore
    if vllm_config.parallel_config.enable_expert_parallel:
        # Currently only DeepEP highthroughput supports SM control so this
        # only affects that case.
        all2all_manager = get_ep_group().device_communicator.all2all_manager

        if all2all_manager.max_sms_used() is not None:
            comm_sms = min(comm_sms, all2all_manager.max_sms_used())

        if comm_sms > 0:

            def set_comm_sms(sms):
                return all2all_manager.set_num_sms(sms)

    # TODO(lucas): support other kernels besides DeepGEMM
    set_compute_sms = _ignore
    if has_deep_gemm() and comm_sms > 0:
        import deep_gemm as dg

        def set_compute_sms(sms):
            return dg.set_num_sms(sms)

    return SMControlContextManager(comm_sms=comm_sms,
                                   set_comm_sms=set_comm_sms,
                                   set_compute_sms=set_compute_sms)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(f"Attribute {key} not exists in the runnable of "
|
||||
f"supagraph wrapper: {self.runnable}")
|
||||
|
||||
def unwrap(self) -> Callable:
    """Return the original runnable, for callers that need direct access."""
    original = self.runnable
    return original
|
||||
|
||||
def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
    """
    Capture a supagraph for a microbatched run.

    The logic here is somewhat complicated because we need to make sure that
    each of the ubatch threads initialize the supa context before we start
    the graph capture.

    The flow is as follows:
    1. The main thread starts up each ubatch thread. Each thread will
    initialize its supa context (torch.supa.current_blas_handle())
    before going to sleep upon entering the ubatch_context.

    2. The main thread starts the graph capture and wakes up the first
    ubatch thread.

    3. Each ubatch thread runs the model to completion and returns the
    completed output tensors back to the main thread.

    4. The main thread stores the captured supagraph along with its metadata
    and returns

    Args:
        ubatch_metadata: Sequence of per-ubatch metadata. Entries 0 and 1
            are indexed directly, so two ubatches are expected.
        model: Callable run by every ubatch thread with the keyword
            arguments input_ids, positions, intermediate_tensors and
            inputs_embeds.

    Returns:
        The per-ubatch outputs, ordered by context id and concatenated
        along dim 0. The same tensor is stored on the captured graph's
        metadata.
    """

    @torch.inference_mode()
    def _capture_ubatch_thread(results, ubatch_metadata):
        torch.supa.set_device(self.device)
        ubatch_context = ubatch_metadata.context
        # Touch the BLAS handle on both streams before capture starts
        # (step 1 of the flow described above).
        with torch.supa.stream(ubatch_context.compute_stream):
            _ = torch.supa.current_blas_handle()
        with torch.supa.stream(ubatch_context.comm_stream):
            _ = torch.supa.current_blas_handle()
        with ubatch_context:
            model_output = model(
                input_ids=ubatch_metadata.input_ids,
                positions=ubatch_metadata.positions,
                intermediate_tensors=ubatch_metadata.intermediate_tensors,
                inputs_embeds=ubatch_metadata.inputs_embeds,
            )

        # Tag the output with the context id so the main thread can restore
        # ubatch order regardless of which thread finishes first.
        results.append((ubatch_metadata.context.id, model_output))

    results: list[tuple[int, torch.Tensor]] = []
    compute_stream = ubatch_metadata[0].context.compute_stream
    # Captured graphs are keyed by the combined token count of the two
    # ubatches.
    num_tokens = ubatch_metadata[0].num_tokens + \
        ubatch_metadata[1].num_tokens

    # Ubatches will manually manage the forward context, so we override
    # it to None here so we can have it restored correctly later
    with override_forward_context(None):
        ubatch_threads = []
        for metadata in ubatch_metadata:
            thread = threading.Thread(target=_capture_ubatch_thread,
                                      args=(
                                          results,
                                          metadata,
                                      ))
            ubatch_threads.append(thread)
            thread.start()
        self.ready_barrier.wait()  # Wait for both threads to be ready

        # Capture the supagraph
        supagraph_metadata = \
            SUPAGraphMetaData(
                supagraph=torch.supa.SUPAGraph(),
                ubatch_metadata=ubatch_metadata,
            )
        if self.graph_pool is not None:
            set_graph_pool_id(self.graph_pool)
        else:
            set_graph_pool_id(current_platform.graph_pool_handle())
        with torch.supa.graph(supagraph_metadata.supagraph,
                              stream=compute_stream,
                              pool=self.graph_pool):
            # Wake the first ubatch thread only once capture is active.
            ubatch_metadata[0].context.cpu_wait_event.set()
            for thread in ubatch_threads:
                thread.join()
            sorted_results = [value for position, value in sorted(results)]
            result = torch.cat(sorted_results, dim=0)
            supagraph_metadata.outputs = result
        self.supagraphs[num_tokens] = supagraph_metadata
    return supagraph_metadata.outputs
|
||||
|
||||
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
    """Run the model once per ubatch on worker threads, without capture.

    Args:
        ubatch_metadata: Sequence of per-ubatch metadata; one thread is
            started for each entry.
        model: Callable invoked with the keyword arguments input_ids,
            positions, intermediate_tensors and inputs_embeds.

    Returns:
        The per-ubatch outputs, ordered by context id and concatenated
        along dim 0.
    """

    @torch.inference_mode()
    def _ubatch_thread(results, model, ubatch_metadata):
        with ubatch_metadata.context:
            model_output = model(
                input_ids=ubatch_metadata.input_ids,
                positions=ubatch_metadata.positions,
                intermediate_tensors=ubatch_metadata.intermediate_tensors,
                inputs_embeds=ubatch_metadata.inputs_embeds,
            )
        # Tag the output with the context id so ubatch order can be
        # restored after the threads finish.
        results.append((ubatch_metadata.context.id, model_output))

    results: list[tuple[int, torch.Tensor]] = []

    # Ubatch threads will manually manage the forward context, so we
    # override it to None here so we can have it restored correctly
    # after both threads have finished
    with override_forward_context(None):
        ubatch_threads = []
        for metadata in ubatch_metadata:
            thread = threading.Thread(target=_ubatch_thread,
                                      args=(
                                          results,
                                          model,
                                          metadata,
                                      ))
            ubatch_threads.append(thread)
            thread.start()
        self.ready_barrier.wait()  # Wait for both threads to be ready
        # Wake the first ubatch thread.
        ubatch_metadata[0].context.cpu_wait_event.set()
        for thread in ubatch_threads:
            thread.join()
    sorted_results = [value for position, value in sorted(results)]
    result = torch.cat(sorted_results, dim=0)
    return result
|
||||
|
||||
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
                          positions, inputs_embeds, intermediate_tensors,
                          compute_stream, dp_metadata, batch_descriptor,
                          supagraph_runtime_mode) -> list[UbatchMetadata]:
    """Build one UbatchMetadata per entry of ``ubatch_slices``.

    A forward context is created for every ubatch, the ubatch contexts are
    created from them, and the model inputs are sliced by each ubatch's
    ``token_slice``.

    Returns:
        The metadata objects, in the order of ``ubatch_slices``.
    """
    # Create one forward context per ubatch
    forward_contexts = [
        create_forward_context(
            None if attn_metadata is None else attn_metadata[idx],
            self.vllm_config,
            dp_metadata=dp_metadata,
            batch_descriptor=batch_descriptor,
            supagraph_runtime_mode=supagraph_runtime_mode)
        for idx, _ in enumerate(ubatch_slices)
    ]

    ubatch_ctxs = make_ubatch_contexts(
        num_micro_batches=len(ubatch_slices),
        comm_stream=self.comm_stream,
        compute_stream=compute_stream,
        forward_contexts=forward_contexts,
        ready_barrier=self.ready_barrier)

    per_ubatch: list[UbatchMetadata] = []
    for idx, current_slice in enumerate(ubatch_slices):
        token_range = current_slice.token_slice
        ids_part, positions_part, embeds_part, intermediates_part = \
            self._slice_model_inputs(token_range, input_ids, positions,
                                     inputs_embeds, intermediate_tensors)
        entry = UbatchMetadata(
            context=ubatch_ctxs[idx],
            input_ids=ids_part,
            positions=positions_part,
            inputs_embeds=embeds_part,
            intermediate_tensors=intermediates_part,
            num_tokens=token_range.stop - token_range.start)
        per_ubatch.append(entry)

    return per_ubatch
|
||||
|
||||
def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions,
|
||||
inputs_embeds, intermediate_tensors):
|
||||
sliced_input_ids = input_ids[tokens_slice]
|
||||
# if we are using mrope. Mrope adds an additional dimension to the
|
||||
# positions tensor
|
||||
if positions.ndim == 2:
|
||||
sliced_positions = positions[:, tokens_slice]
|
||||
else:
|
||||
sliced_positions = positions[tokens_slice]
|
||||
sliced_inputs_embeds = inputs_embeds[
|
||||
tokens_slice] if inputs_embeds else None
|
||||
sliced_intermediate_tensors = intermediate_tensors[
|
||||
tokens_slice] if intermediate_tensors else None
|
||||
|
||||
return (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
|
||||
sliced_intermediate_tensors)
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Run the wrapped model, micro-batched when the context asks for it.

    Reads the batch descriptor, ubatch slices and graph runtime mode from
    the current forward context. Without ubatch slices the call is passed
    to ``self.runnable`` or ``self.supagraph_wrapper``. With ubatch slices
    the call either replays an already stored supagraph, captures a new
    one, or runs the micro-batches directly.
    """
    ctx = get_forward_context()
    batch_descriptor = ctx.batch_descriptor
    ubatch_slices = ctx.ubatch_slices
    supagraph_runtime_mode = ctx.cudagraph_runtime_mode

    # If there's no ubatching, just run the runnable object
    if ubatch_slices is None:
        # This is to account for the case where ubatching was aborted.
        # When we capture full graphs we only capture one graph per shape,
        # meaning that if we have a ubatched supagraph for the current
        # num_tokens, we don't have a non-ubatched one. Without this
        # check, the supagraph wrapper will try to capture a supagraph
        # for this shape during a normal run.
        if supagraph_runtime_mode is SUPAGraphMode.FULL:
            assert batch_descriptor is not None
            if batch_descriptor.num_tokens in self.supagraphs:
                supagraph_runtime_mode = SUPAGraphMode.NONE

        if supagraph_runtime_mode in (SUPAGraphMode.NONE,
                                      SUPAGraphMode.PIECEWISE):
            return self.runnable(*args, **kwargs)

        assert self.supagraph_wrapper is not None
        return self.supagraph_wrapper(*args, **kwargs)

    attn_metadata = ctx.attn_metadata
    # The graph key is twice the token count of the first micro-batch.
    num_tokens = (ubatch_slices[0].token_slice.stop -
                  ubatch_slices[0].token_slice.start) * 2
    input_ids = kwargs['input_ids']
    positions = kwargs['positions']
    intermediate_tensors = kwargs['intermediate_tensors']
    inputs_embeds = kwargs['inputs_embeds']
    compute_stream = torch.supa.current_stream()

    dp_metadata = ctx.dp_metadata

    # We shouldn't be here unless we are running with multiple DP ranks
    assert dp_metadata is not None

    full_graph_mode = supagraph_runtime_mode is SUPAGraphMode.FULL

    # A graph for this size was already captured: replay it.
    if full_graph_mode and num_tokens in self.supagraphs:
        captured = self.supagraphs[num_tokens]
        captured.supagraph.replay()
        return captured.outputs

    # Both remaining paths (capture and plain run) need the same
    # per-micro-batch metadata.
    ubatch_metadata = self._make_ubatch_metadata(
        ubatch_slices=ubatch_slices,
        attn_metadata=attn_metadata,
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        compute_stream=compute_stream,
        dp_metadata=dp_metadata,
        batch_descriptor=batch_descriptor,
        supagraph_runtime_mode=SUPAGraphMode.NONE)

    with self.sm_control:
        if full_graph_mode:
            return self._capture_ubatches(ubatch_metadata, self.model)
        return self._run_ubatches(ubatch_metadata, self.model)
|
||||
155
vllm_br/v1/worker/supagraph_dispatcher.py
Normal file
155
vllm_br/v1/worker/supagraph_dispatcher.py
Normal file
@@ -0,0 +1,155 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm_br.config.compilation import SUPAGraphMode
|
||||
from vllm_br.forward_context import BatchDescriptor
|
||||
|
||||
_BATCH_SIZE_ALIGNMENT = 8
|
||||
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
|
||||
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 34)
|
||||
]
|
||||
|
||||
|
||||
class SupagraphDispatcher:
    """
    Runtime supagraph dispatcher to dispatch keys for multiple sets of
    supagraphs.

    The dispatcher stores one set of dispatch keys per supagraph runtime
    mode (PIECEWISE, FULL and FULL_DECODE_ONLY). The keys stored in the
    dispatcher are the only source of truth for valid supagraphs that can
    be dispatched at runtime.

    At runtime, the dispatch method generates the runtime supagraph mode
    and the valid key (batch descriptor) based on the input key, or
    (NONE, None) when no stored key matches. After dispatching
    (communicated via the forward context), the supagraph wrappers will
    trust the dispatch key to do either capturing or replaying (if the mode
    matched), or pass through to the underlying runnable without supagraph
    (if the mode did not match or the mode is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # Fall back to the built-in size list when the configured maximum
        # capture size is 0 or larger than 256.
        max_capture_size = self.compilation_config.max_capture_size
        self.use_default_list = (max_capture_size > 256
                                 or max_capture_size == 0)
        if self.use_default_list:
            self.capture_list = _BATCH_SIZES_TO_CAPTURE
        else:
            self.capture_list = \
                self.compilation_config.cudagraph_capture_sizes

        # TODO(liming): Remove this hard code once we support piecewise
        self.supagraph_mode = SUPAGraphMode.FULL

        # Dict to store valid supagraph dispatching keys, one set per mode.
        self.supagraph_keys: dict[SUPAGraphMode, set[BatchDescriptor]] = {
            mode: set()
            for mode in (SUPAGraphMode.PIECEWISE, SUPAGraphMode.FULL,
                         SUPAGraphMode.FULL_DECODE_ONLY)
        }

        assert not self.supagraph_mode.requires_piecewise_compilation() or \
            (self.compilation_config.level == CompilationLevel.PIECEWISE and
             self.compilation_config.splitting_ops_contain_attention()), \
            "Compilation level should be CompilationLevel.PIECEWISE when "\
            "supagraph_mode piecewise supagraphs is used, "\
            f"supagraph_mode={self.supagraph_mode}, "\
            f"compilation_level={self.compilation_config.level}, "\
            f"splitting_ops={self.compilation_config.splitting_ops}"

        self.keys_initialized = False

    def add_supagraph_key(self, runtime_mode: SUPAGraphMode,
                          batch_descriptor: BatchDescriptor):
        """Register ``batch_descriptor`` as a valid key for ``runtime_mode``."""
        valid_modes = (SUPAGraphMode.PIECEWISE,
                       SUPAGraphMode.FULL_DECODE_ONLY, SUPAGraphMode.FULL)
        assert runtime_mode in valid_modes, \
            f"Invalid supagraph runtime mode: {runtime_mode}"
        self.supagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_supagraph_keys(self, supagraph_mode: SUPAGraphMode,
                                  uniform_decode_query_len: int):
        """Populate the dispatch keys for ``supagraph_mode``.

        This should be called only after the attention backend is
        initialized.

        For FULL and FULL_DECODE_ONLY, one uniform-decode key is added for
        every capture size between ``uniform_decode_query_len`` and
        ``uniform_decode_query_len * max_num_seqs`` (both inclusive). Other
        modes add no keys. The dispatcher is marked initialized in every
        case.

        Note: all valid keys are created, but there is no guarantee that
        every key will actually be used.
        """
        decode_modes = (SUPAGraphMode.FULL, SUPAGraphMode.FULL_DECODE_ONLY)
        if supagraph_mode in decode_modes:
            max_num_tokens = (uniform_decode_query_len *
                              self.vllm_config.scheduler_config.max_num_seqs)
            for size in self.capture_list:
                if uniform_decode_query_len <= size <= max_num_tokens:
                    self.add_supagraph_key(
                        supagraph_mode,
                        BatchDescriptor(num_tokens=size,
                                        uniform_decode=True))

        self.keys_initialized = True

    def dispatch(
        self, batch_descriptor: BatchDescriptor
    ) -> tuple[SUPAGraphMode, Optional[BatchDescriptor]]:
        """
        Given a batch descriptor, dispatch to a supagraph mode.
        A new batch descriptor is returned as we might dispatch a uniform
        batch to a graph that supports a more general batch (uniform to
        non-uniform).
        """
        # if not initialized, just skip dispatching.
        if not self.keys_initialized:
            logger.warning_once("supagraph dispatching keys are not "
                                "initialized. No supagraph will be used.")
            return SUPAGraphMode.NONE, None

        # An exact key match wins; FULL_DECODE_ONLY is checked before FULL.
        for mode in (SUPAGraphMode.FULL_DECODE_ONLY, SUPAGraphMode.FULL):
            if batch_descriptor in self.supagraph_keys[mode]:
                return mode, batch_descriptor

        # otherwise, check if the non-uniform key exists for full supagraph
        relaxed_key = batch_descriptor.non_uniform
        if relaxed_key in self.supagraph_keys[SUPAGraphMode.FULL]:
            return SUPAGraphMode.FULL, relaxed_key

        # The PIECEWISE key set is not consulted here (see the TODO in
        # __init__ about piecewise support).

        # finally, just return no supagraphs
        return SUPAGraphMode.NONE, None
|
||||
195
vllm_br/v1/worker/ubatching.py
Normal file
195
vllm_br/v1/worker/ubatching.py
Normal file
@@ -0,0 +1,195 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import threading
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import forward_context
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import current_stream
|
||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||
|
||||
|
||||
class SUPAUBatchContext:
    """
    Context manager for micro-batching synchronization using threading events.

    Each context owns a CPU "wait" event it blocks on and a CPU "signal"
    event it sets to hand control over, plus two device events used to
    order work between the compute stream and the communication stream.
    """

    def __init__(self,
                 id: int,
                 comm_stream: torch.supa.Stream,
                 compute_stream: torch.supa.Stream,
                 forward_context: ForwardContext,
                 ready_barrier: threading.Barrier,
                 cpu_wait_event: threading.Event,
                 cpu_signal_event: threading.Event,
                 gpu_comm_done_event: torch.supa.Event,
                 gpu_compute_done_event: torch.supa.Event,
                 schedule: str = "default"):
        self.id = id
        self.comm_stream = comm_stream
        self.compute_stream = compute_stream
        self.forward_context = forward_context
        self.ready_barrier = ready_barrier
        self.cpu_wait_event = cpu_wait_event
        self.cpu_signal_event = cpu_signal_event
        # Stream this context considers active; starts on compute.
        self.current_stream = compute_stream
        self.gpu_comm_done_event = gpu_comm_done_event
        self.gpu_compute_done_event = gpu_compute_done_event
        self.schedule = schedule
        # Optional callable run once by maybe_run_recv_hook().
        self.recv_hook = None

    def __enter__(self):
        # This module does not define the thread/context registries itself,
        # so referencing them as module globals raised NameError. They are
        # taken from the upstream module this file already imports from.
        # NOTE(review): assumes the pinned vLLM version defines both names
        # in vllm.v1.worker.ubatching - confirm.
        from vllm.v1.worker.ubatching import (_CURRENT_CONTEXTS,
                                              _THREAD_ID_TO_CONTEXT)
        _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
        _CURRENT_CONTEXTS[self.id] = self
        self.ready_barrier.wait()

        # Block until this context is signalled, then take over.
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()
        # Assume we want to start on the compute stream
        self.update_stream(self.compute_stream)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Same registries as in __enter__ (see the note there).
        from vllm.v1.worker.ubatching import (_CURRENT_CONTEXTS,
                                              _THREAD_ID_TO_CONTEXT)
        _CURRENT_CONTEXTS[self.id] = None
        del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        self.maybe_run_recv_hook()
        # Wake the context waiting on our signal event.
        self.cpu_signal_event.set()
        self.cpu_wait_event.clear()
        # Never suppress exceptions raised inside the with-block.
        return False

    def _restore_context(self):
        """Install this context's forward context as the active one."""
        forward_context._forward_context = self.forward_context

    def update_stream(self, stream):
        """Record ``stream`` as current and switch to it if needed."""
        self.current_stream = stream
        if current_stream() != self.current_stream:
            torch.supa.set_stream(self.current_stream)

    def _signal_comm_done(self):
        self.gpu_comm_done_event.record(self.comm_stream)

    def _signal_compute_done(self):
        self.gpu_compute_done_event.record(self.compute_stream)

    def _wait_compute_done(self):
        self.comm_stream.wait_event(self.gpu_compute_done_event)

    def _wait_comm_done(self):
        self.compute_stream.wait_event(self.gpu_comm_done_event)

    def _cpu_yield(self):
        """Hand control to the other context and block until signalled."""
        # It is critical for correctness that only one thread is running
        # at a time. These asserts just make sure that this is the only
        # thread running before waking the other one up and going to sleep
        assert forward_context._forward_context == self.forward_context
        assert current_stream() == self.current_stream
        assert not self.cpu_wait_event.is_set()

        self.cpu_signal_event.set()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm(self):
        self.update_stream(self.comm_stream)

    def switch_to_compute(self):
        self.update_stream(self.compute_stream)

    def switch_to_comm_sync(self):
        """Switch to the comm stream and make it wait for compute."""
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def switch_to_compute_sync(self):
        """Switch to the compute stream and make it wait for comm."""
        self._signal_comm_done()
        self.update_stream(self.compute_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        """Run the pending receive hook, if any, exactly once."""
        if self.recv_hook is not None:
            self.recv_hook()
            self.recv_hook = None

    def yield_(self):
        """Yield the CPU, then resume on the stream that was current."""
        self.current_stream = current_stream()
        self._cpu_yield()
        self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        assert current_stream() == self.compute_stream
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()
|
||||
|
||||
|
||||
def supa_make_ubatch_contexts(
    num_micro_batches: int,
    compute_stream: torch.supa.Stream,
    comm_stream: torch.supa.Stream,
    forward_contexts: list[ForwardContext],
    ready_barrier: threading.Barrier,
    schedule: str = "default",
) -> list[UBatchContext]:
    """
    Create the context managers for micro-batching synchronization.

    Args:
        num_micro_batches: Number of micro-batches; must be 2.
        compute_stream: Stream shared by all contexts for compute work.
        comm_stream: Stream shared by all contexts for communication work.
        forward_contexts: One forward context per micro-batch (exactly 2).
        ready_barrier: Barrier shared by all contexts.
        schedule: Schedule name stored on each context.

    Returns:
        One context per micro-batch. Context ``i`` waits on CPU event ``i``
        and signals CPU event ``(i + 1) % num_micro_batches``.
    """
    # The docstring has to come first to be picked up as ``__doc__``; it
    # previously followed this assert and was a dead string expression.
    assert num_micro_batches == 2, "only been tested with 2 micro-batches"
    assert len(forward_contexts) == 2

    cpu_events = [threading.Event() for _ in range(num_micro_batches)]
    gpu_comm_done_events = [
        torch.supa.Event() for _ in range(num_micro_batches)
    ]
    gpu_compute_done_events = [
        torch.supa.Event() for _ in range(num_micro_batches)
    ]

    ctxs = []
    for i in range(num_micro_batches):
        # `UBatchContext` is looked up in module globals at call time.
        ctx = UBatchContext(id=i,
                            compute_stream=compute_stream,
                            comm_stream=comm_stream,
                            forward_context=forward_contexts[i],
                            ready_barrier=ready_barrier,
                            cpu_wait_event=cpu_events[i],
                            cpu_signal_event=cpu_events[(i + 1) %
                                                        num_micro_batches],
                            gpu_comm_done_event=gpu_comm_done_events[i],
                            gpu_compute_done_event=gpu_compute_done_events[i],
                            schedule=schedule)
        ctxs.append(ctx)

    return ctxs
|
||||
|
||||
|
||||
# Rebind the names imported from vllm.v1.worker.ubatching so that, within this
# module, they refer to the SUPA implementations defined above.
UBatchContext = SUPAUBatchContext
make_ubatch_contexts = supa_make_ubatch_contexts
|
||||
86
vllm_br/v1/worker/utils.py
Normal file
86
vllm_br/v1/worker/utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
|
||||
def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: Optional[int] = 1,
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches, ordered by layer index.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner; must be
            empty on entry.
        num_attn_module: Forwarded to `extract_layer_index`.

    Raises:
        NotImplementedError: If several layers share one layer index on a
            platform other than CUDA, XPU or SUPA.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Group the layer names by layer index so the runner list can be built
    # in layer order.
    names_by_index = defaultdict(list)
    for name in kv_caches:
        index = extract_layer_index(name, num_attn_module)
        names_by_index[index].append(name)

    for index in sorted(names_by_index):
        names = names_by_index[index]
        if len(names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder
            # layer has different layer_name but the same layer_index.

            # TODO - analyze where runner_kv_caches is used and the right
            # way to ensure it properly reflects multiple attention layers
            # in the same decoder block.

            # We know that the GPU runner is not impacted by this
            # case. Some test code depends on runner_kv_caches, but
            # not in a way that's impacted by ignoring this.
            tolerated = (current_platform.is_cuda()
                         or current_platform.is_xpu()
                         or current_platform.is_supa())
            if not tolerated:
                raise NotImplementedError
        # Only the first layer name of each index goes into the runner list.
        runner_kv_caches.append(kv_caches[names[0]])

    # Bind kv_caches to forward context
    for name, cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[name].kv_cache = [cache]
|
||||
429
vllm_br/v1/worker/worker.py
Normal file
429
vllm_br/v1/worker/worker.py
Normal file
@@ -0,0 +1,429 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A GPU worker class."""
|
||||
import copy
|
||||
import datetime
|
||||
import gc
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm_br.envs as br_envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||
DraftTokenIds, ModelRunnerOutput)
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
from vllm_br.platform import SUPAPlatform
|
||||
from vllm_br.utils import GiB_bytes, SUPAMemorySnapshot
|
||||
from vllm_br.v1.worker.model_runner import SUPAModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
class SUPAWorker(WorkerBase):
|
||||
|
||||
def __init__(
    self,
    vllm_config: VllmConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    is_driver_worker: bool = False,
):
    """Set up the worker state and, when requested, the torch profiler.

    All arguments are forwarded unchanged to the base class constructor.
    The profiler is created only when ``VLLM_TORCH_PROFILER_DIR`` is set;
    otherwise ``self.profiler`` is None.
    """
    super().__init__(
        vllm_config=vllm_config,
        local_rank=local_rank,
        rank=rank,
        distributed_init_method=distributed_init_method,
        is_driver_worker=is_driver_worker,
    )
    self.kv_transfer_config = vllm_config.kv_transfer_config

    if self.model_config.trust_remote_code:
        # note: lazy import to avoid importing torch before initializing
        from vllm.utils import init_cached_hf_modules

        init_cached_hf_modules()

    # Buffers saved before sleep
    self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

    # Torch profiler. Enabled and configured through env vars:
    # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
    trace_dir = envs.VLLM_TORCH_PROFILER_DIR
    if not trace_dir:
        self.profiler = None
        return

    logger.info(
        "Profiling enabled. Traces will be saved to: %s",
        trace_dir,
    )
    trace_handler = torch.profiler.tensorboard_trace_handler(
        trace_dir, use_gzip=True)
    profiled_activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.SUPA,  # type: ignore
    ]
    profile_schedule = torch.profiler.schedule(wait=0,
                                               warmup=0,
                                               active=1,
                                               repeat=1)
    self.profiler = torch.profiler.profile(
        on_trace_ready=trace_handler,
        activities=profiled_activities,
        schedule=profile_schedule,
        profile_memory=False,
        record_shapes=True,
        with_stack=False,
        use_supa_simple=True,  # type: ignore
    )
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
    """Sleep-mode entry point; not implemented for this worker.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
    """Wake-up entry point; not implemented for this worker.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
                     num_cpu_blocks: int) -> None:
    """Record the given block counts on ``self.cache_config``."""
    self.cache_config.num_gpu_blocks = num_gpu_blocks
    self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
def init_device(self):
    """Select the SUPA device, set up distributed state and the runner.

    Raises:
        RuntimeError: If the configured device type is not ``"supa"``.
    """
    if self.device_config.device.type != "supa":
        raise RuntimeError(
            f"Not support device type: {self.device_config.device}")

    # The device index is the local rank, shifted by the optional
    # "device_cursor" entry of the KV-transfer extra config.
    device_index = self.local_rank
    if self.kv_transfer_config is not None:
        device_cursor = self.kv_transfer_config.get_from_extra_config(
            "device_cursor", 0)
        device_index = self.local_rank + int(device_cursor)
    self.device = torch.device(f"supa:{device_index}")
    SUPAPlatform.set_device(self.device)

    _check_if_gpu_supports_dtype(self.model_config.dtype)
    # Initialize the distributed environment BEFORE taking
    # memory snapshot
    # This ensures SUCCL buffers are allocated before we measure
    # available memory
    self._init_worker_distributed_environment()

    # Set random seed.
    set_random_seed(self.model_config.seed)
    gc.collect()
    torch.supa.empty_cache()
    self.init_gpu_memory = SUPAPlatform.get_device_total_memory()

    # Construct the model runner
    self.model_runner: SUPAModelRunner = SUPAModelRunner(  # type: ignore
        self.vllm_config, self.device)

    if self.rank == 0:
        # If usage stat is enabled, collect relevant info.
        report_usage_stats(self.vllm_config)
|
||||
|
||||
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
||||
# to hijack tensor allocation.
|
||||
def load_model(self) -> None:
    """Load the model through the model runner.

    Raises:
        NotImplementedError: If sleep mode is enabled in the model config.
    """
    if self.vllm_config.model_config.enable_sleep_mode:
        raise NotImplementedError('SUPA do not support sleep mode')

    from contextlib import nullcontext

    # No special allocation context is used; the load runs under a no-op
    # context manager.
    with nullcontext():
        self.model_runner.load_model()
|
||||
|
||||
@torch.inference_mode()
def determine_available_memory(self) -> int:
    """Profiles the memory usage of the model to determine how much
    memory can be used for KV cache without OOMs.

    The engine will first conduct a profiling run of the model. Then, it
    calculates the memory that can be used for KV cache in bytes as
    ``total_gpu_memory * gpu_memory_utilization - peak_memory``.

    .. tip::
        You may limit the usage of GPU memory
        by adjusting the `gpu_memory_utilization` parameter.

    Returns:
        The number of bytes available for the KV cache, truncated to int.
    """
    torch.supa.empty_cache()

    _, total_gpu_memory = torch.supa.mem_get_info()
    # Execute a forward pass with dummy inputs to profile the memory usage
    # of the model. The two snapshots bracket the profiling run and are
    # only used for the log message below.
    before_profile = SUPAMemorySnapshot()
    after_profile = SUPAMemorySnapshot()
    before_profile.measure()
    self.model_runner.profile_run()
    after_profile.measure()

    free_gpu_memory, _ = torch.supa.mem_get_info()
    # NOTE(woosuk): Here we assume that the other processes using the same
    # GPU did not change their memory usage during the profiling.
    # NOTE(review): init_device assigns self.init_gpu_memory from
    # SUPAPlatform.get_device_total_memory(), while the message below calls
    # it "Initial free memory" - confirm which one is intended.
    assert self.init_gpu_memory > free_gpu_memory, (
        "Error in memory profiling. "
        f"Initial free memory {self.init_gpu_memory}, current free memory"
        f" {free_gpu_memory}. This happens when the GPU memory was "
        "not properly cleaned up before initializing the vLLM instance.")

    # Memory reported as allocated after the profiling run.
    # NOTE(review): despite the name, this reads memory_allocated() rather
    # than a recorded peak - confirm this matches the intended budget.
    peak_memory = torch.supa.memory_allocated()
    # Check for any memory left around that may have been allocated on the
    # gpu outside of `torch`. NCCL operations, for example, can use a few
    # GB during a forward pass
    torch.supa.empty_cache()
    torch_allocated_bytes = SUPAPlatform.get_memory_stats(
        self.device, "allocated_bytes.all.current")
    total_allocated_bytes = (torch.supa.mem_get_info()[1] -
                             torch.supa.mem_get_info()[0])
    non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
    # The non-torch allocations are only reported in the log message; the
    # adjustment of peak_memory below is disabled.
    #if non_torch_allocations > 0:
    #    peak_memory += non_torch_allocations
    available_kv_cache_memory = (
        total_gpu_memory * self.cache_config.gpu_memory_utilization -
        peak_memory)
    memory_for_current_instance = total_gpu_memory * \
        self.cache_config.gpu_memory_utilization
    diff_profile = after_profile - before_profile
    msg = (f"Memory profiling takes {diff_profile.timestamp:.2f} seconds\n"
           "the current vLLM instance can use "
           "total_gpu_memory "
           f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
           " x gpu_memory_utilization "
           f"({self.cache_config.gpu_memory_utilization:.2f})"
           f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
           "model weights take "
           f"{(self.model_runner.model_memory_usage / GiB_bytes):.2f}GiB;"
           " non_torch_memory takes "
           f"{(non_torch_allocations / GiB_bytes):.2f}GiB;"
           " PyTorch activation peak memory takes "
           f"{(diff_profile.torch_peak / GiB_bytes):.2f}GiB;"
           " the rest of the memory reserved for KV Cache is "
           f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
    logger.info(msg)
    return int(available_kv_cache_memory)
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
    """Return the KV cache spec reported by the model runner."""
    return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
    """Allocate the KV cache described by ``kv_cache_config``.

    The allocation itself is delegated to the model runner.

    Raises:
        NotImplementedError: If sleep mode is enabled in the model config.
    """
    sleep_mode_requested = self.vllm_config.model_config.enable_sleep_mode
    if sleep_mode_requested:
        raise NotImplementedError('SUPA do not support sleep mode')

    from contextlib import nullcontext

    # No dedicated allocator scope is used here, so the allocation runs
    # inside a no-op context manager.
    with nullcontext():
        self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
    """Warm up / compile the model, capture graphs and warm up the sampler.

    Steps, in order:
      1. Run a dummy forward pass for every configured compile size (largest
         first). When eager mode is not enforced, sizes listed in
         ``scheduler_config.cuda_graph_sizes`` are left out.
      2. Ask the model runner to drop any LoRAs left over from the dummy runs.
      3. Call ``capture_model`` unless eager mode is enforced.
      4. On the last pipeline-parallel rank, run one dummy pass followed by a
         dummy pooler or sampler run.
      5. Reset the random seed to the configured model seed.
    """
    runner = self.model_runner

    # Sizes that are not in the graph-capture sizes but that users still
    # want compiled for better performance, e.g. the max-num-batched-token
    # size in chunked prefill.
    sizes_to_warm = self.vllm_config.compilation_config.compile_sizes.copy()
    if not self.model_config.enforce_eager:
        not_captured = []
        for candidate in sizes_to_warm:
            if candidate not in self.scheduler_config.cuda_graph_sizes:
                not_captured.append(candidate)
        sizes_to_warm = not_captured

    for size in sorted(sizes_to_warm, reverse=True):
        logger.info("Compile and warming up model for size %d", size)
        runner._dummy_run(size, skip_eplb=True, remove_lora=False)
    runner.maybe_remove_all_loras(runner.lora_config)

    if not self.model_config.enforce_eager:
        runner.capture_model()

    # Warm up the sampler and preallocate memory buffers for logits and
    # other sampling related tensors of the max possible shape to avoid
    # memory fragmentation issues.
    # NOTE: This runs after `capture_model` on purpose, so that these
    # buffers are not cleared by `SUPAPlatform.empty_cache`.
    if get_pp_group().is_last_rank:
        sched = self.scheduler_config
        max_num_reqs = min(sched.max_num_seqs, sched.max_num_batched_tokens)
        hidden_states, last_hidden_states = runner._dummy_run(
            num_tokens=max_num_reqs,
            skip_eplb=True,
        )
        if runner.is_pooling_model:
            runner._dummy_pooler_run(hidden_states)
        else:
            runner._dummy_sampler_run(hidden_states=last_hidden_states)

    # Reset the seed so the random state is not affected by the model
    # initialization and profiling done above.
    set_random_seed(self.model_config.seed)
|
||||
|
||||
def get_model(self) -> nn.Module:
    """Return the model object held by the model runner."""
    runner = self.model_runner
    return runner.get_model()
|
||||
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
    """Return the tasks the model runner reports as supported."""
    runner = self.model_runner
    return runner.get_supported_tasks()
|
||||
|
||||
@torch.inference_mode()
def execute_model(
    self,
    scheduler_output: "SchedulerOutput",
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
    """Run one scheduled step on this worker.

    When tokens are scheduled and this is not the first pipeline-parallel
    rank, the intermediate tensors are received from the pipeline group
    first. If the model runner returns a ``ModelRunnerOutput`` or
    ``AsyncModelRunnerOutput`` it is returned as is. Otherwise the runner
    returned ``IntermediateTensors``, which are sent on through the pipeline
    group; in that case the return value only carries the KV connector
    output, if there is any.

    With ``br_envs.VLLM_PP_CPU_SEND_RECV`` set, received tensors are moved to
    the current SUPA device after the receive and outgoing tensors are
    converted with ``.cpu()`` before the send.

    Returns:
        The runner output; ``None`` when tensors were forwarded and there is
        no KV connector output; ``EMPTY_MODEL_RUNNER_OUTPUT`` when the KV
        connector output has nothing finished sending or receiving;
        otherwise a shallow copy of ``EMPTY_MODEL_RUNNER_OUTPUT`` carrying
        the KV connector output.
    """
    has_scheduled_tokens = scheduler_output.total_num_scheduled_tokens > 0

    intermediate_tensors = None
    if has_scheduled_tokens and not get_pp_group().is_first_rank:
        if br_envs.VLLM_PP_CPU_SEND_RECV:
            # use cpu send/recv: move each received tensor onto the device
            received = {}
            for name, tensor in get_pp_group().recv_tensor_dict().items():
                received[name] = tensor.to(torch.supa.current_device())
        else:
            received = get_pp_group().recv_tensor_dict()
        intermediate_tensors = IntermediateTensors(received)

    result = self.model_runner.execute_model(scheduler_output,
                                             intermediate_tensors)
    if isinstance(result, (ModelRunnerOutput, AsyncModelRunnerOutput)):
        return result

    # From here on the runner produced tensors for the next pipeline stage.
    assert isinstance(result, IntermediateTensors)
    parallel_config = self.vllm_config.parallel_config
    assert parallel_config.distributed_executor_backend != "external_launcher"
    assert not get_pp_group().is_last_rank

    if br_envs.VLLM_PP_CPU_SEND_RECV:
        # use cpu send/recv
        outgoing = {name: tensor.cpu()
                    for name, tensor in result.tensors.items()}
    else:
        outgoing = result.tensors
    get_pp_group().send_tensor_dict(outgoing)

    kv_connector_output = result.kv_connector_output
    if not kv_connector_output:
        return None

    # In case of PP with kv transfer, the kv_connector_output has to be
    # passed through.
    anything_finished = (kv_connector_output.finished_sending
                         or kv_connector_output.finished_recving)
    if not anything_finished:
        return EMPTY_MODEL_RUNNER_OUTPUT

    passthrough = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    passthrough.kv_connector_output = kv_connector_output
    return passthrough
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
    """Return whatever draft token ids the model runner hands over."""
    runner = self.model_runner
    return runner.take_draft_token_ids()
|
||||
|
||||
def profile(self, is_start: bool = True):
    """Start or stop the worker's profiler.

    Args:
        is_start: Start the profiler when true, stop it otherwise.

    Raises:
        RuntimeError: If no profiler is configured on this worker.
    """
    profiler = self.profiler
    if profiler is None:
        raise RuntimeError("Profiler is not enabled.")

    action = profiler.start if is_start else profiler.stop
    action()
|
||||
|
||||
def execute_dummy_batch(self) -> None:
    """Run a dummy pass of size 1 through the model runner."""
    runner = self.model_runner
    runner._dummy_run(1)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Forward ``lora_request`` to the model runner and return its result."""
    runner = self.model_runner
    return runner.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
    """Ask the model runner to remove ``lora_id`` and return its result."""
    runner = self.model_runner
    return runner.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> set[int]:
    """Return the LoRA ids the model runner reports."""
    runner = self.model_runner
    return runner.list_loras()
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
    """Ask the model runner to pin ``lora_id`` and return its result."""
    runner = self.model_runner
    return runner.pin_lora(lora_id)
|
||||
|
||||
def check_health(self) -> None:
    """Health probe that always succeeds.

    The worker is considered healthy for as long as it is running, so being
    able to execute this call at all is the whole check.
    """
    return None
|
||||
|
||||
def save_sharded_state(
    self,
    path: str,
    pattern: Optional[str] = None,
    max_size: Optional[int] = None,
) -> None:
    """Save the model runner's model through ``ShardedStateLoader``.

    Args:
        path: Passed to ``ShardedStateLoader.save_model`` as the save
            location.
        pattern: Forwarded unchanged as ``pattern``.
        max_size: Forwarded unchanged as ``max_size``.
    """
    # Newer vLLM releases export ShardedStateLoader from the model_loader
    # package itself, and the old `model_loader.loader` module no longer
    # exists there. Try the package-level import first and fall back to the
    # legacy location so that both layouts work.
    try:
        from vllm.model_executor.model_loader import ShardedStateLoader
    except ImportError:
        from vllm.model_executor.model_loader.loader import ShardedStateLoader

    ShardedStateLoader.save_model(
        self.model_runner.model,
        path,
        pattern=pattern,
        max_size=max_size,
    )
|
||||
|
||||
def _init_worker_distributed_environment(self) -> None:
    """Initialize the distributed environment.

    Configures custom all-reduce from the parallel config, initializes the
    distributed environment and the model-parallel groups, then makes sure
    KV transfer is initialized.
    """
    parallel = self.parallel_config

    set_custom_all_reduce(not parallel.disable_custom_all_reduce)

    # "sccl" is passed in the fifth positional slot, presumably the
    # communication backend name (confirm against
    # init_distributed_environment's signature). The timeout is fixed at
    # 100 seconds.
    init_distributed_environment(
        parallel.world_size,
        self.rank,
        self.distributed_init_method,
        self.local_rank,
        "sccl",
        timeout=datetime.timedelta(seconds=100),
    )

    ensure_model_parallel_initialized(
        parallel.tensor_parallel_size,
        parallel.pipeline_parallel_size,
    )

    ensure_kv_transfer_initialized(self.vllm_config)
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    """Placeholder dtype check that currently accepts every dtype.

    No device capability check is implemented yet, so this function returns
    ``None`` without inspecting ``torch_dtype``. The bfloat16 branch that
    used to follow the early ``return`` could never run and has been
    removed; had it been reachable, it would have raised for every bfloat16
    request regardless of the device's capability.

    Args:
        torch_dtype: The dtype the model is configured to run with.
    """
    # TODO: add checkers
    return
|
||||
Reference in New Issue
Block a user