forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
This commit is contained in:
184
vllm/v1/attention/backends/cpu_attn.py
Normal file
184
vllm/v1/attention/backends/cpu_attn.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
|
||||
TorchSDPAMetadata)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.attention.ops.ipex_attn import PagedAttention
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
|
||||
class TorchSDPABackend(AttentionBackend):
    """Backend descriptor tying together the Torch SDPA attention classes.

    Every accessor simply names the implementation, metadata, state or
    builder class to use; head-size and KV-cache-shape queries are forwarded
    to ``PagedAttention``.
    """

    # This backend does not take a caller-provided output buffer.
    accept_output_buffer: bool = False

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes reported by ``PagedAttention``."""
        return PagedAttention.get_supported_head_sizes()

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Check ``head_size`` against the supported head sizes.

        Raises:
            ValueError: if ``head_size`` is not one of the values returned
                by ``get_supported_head_sizes``.
        """
        allowed_sizes = cls.get_supported_head_sizes()
        if head_size in allowed_sizes:
            return

        # The class name minus its "Backend" suffix labels the error message.
        backend_label = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {backend_label}. "
            f"Supported head sizes are: {allowed_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "TORCH_SDPA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TorchSDPABackendImpl"]:
        """Return the attention implementation class."""
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
        """Return the metadata builder class."""
        return TorchSDPAMetadataBuilderV1

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV-cache shape as computed by ``PagedAttention``."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Always report that cascade attention is not used.

        All arguments are accepted and ignored.
        """
        return False
|
||||
|
||||
|
||||
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
    """Builds ``TorchSDPAMetadata`` objects for a batch of requests.

    ``reorder_batch`` arranges the batch so that prompt-stage requests come
    before decode-stage requests and records how many prompt requests there
    are; ``build`` relies on that count to slice per-request data into a
    prefill part and a decode part.
    """

    def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable) -> None:
        """Store the runner/block table and pre-allocate scratch buffers.

        Args:
            runner: model runner; ``max_num_reqs`` sizes the buffers here.
            kv_cache_spec: accepted but not used by this constructor.
            block_table: block table later read by ``build``.
        """
        self.runner = runner
        self.block_table = block_table

        max_reqs = runner.max_num_reqs

        # Scratch index lists filled by reorder_batch (contents are
        # uninitialised until then).
        self.reorder_prompt_req_index_list = np.empty(max_reqs,
                                                      dtype=np.int64)
        self.reorder_decode_req_index_list = np.empty(max_reqs,
                                                      dtype=np.int64)
        # Number of prompt-stage requests found by the last reorder_batch.
        self.num_prompt_req: int = 0

        # One slot per request plus a leading entry that build() sets to 0.
        self.seq_start_loc_cpu = torch.zeros(
            max_reqs + 1,
            dtype=torch.int32,
            device="cpu",
        )
        # NumPy view of the tensor above; build() writes through this view
        # and hands out slices of the tensor.
        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        """Move prompt-stage requests ahead of decode-stage requests.

        A request counts as prompt-stage while its computed-token count is
        below its prompt-token count. ``scheduler_output`` is not used.

        Returns:
            True if any request states were swapped, False if the batch was
            already in prompt-then-decode order.
        """
        prompt_slots = self.reorder_prompt_req_index_list
        decode_slots = self.reorder_decode_req_index_list

        # Pass 1: split request indices into the two scratch lists, each in
        # ascending index order.
        n_prompt = 0
        n_decode = 0
        for req in range(input_batch.num_reqs):
            in_prompt_stage = (input_batch.num_computed_tokens_cpu[req]
                               < input_batch.num_prompt_tokens[req])
            if in_prompt_stage:
                prompt_slots[n_prompt] = req
                n_prompt += 1
            else:
                decode_slots[n_decode] = req
                n_decode += 1
        assert n_decode + n_prompt == input_batch.num_reqs

        # Remember the prompt count for build().
        self.num_prompt_req = n_prompt

        # Pass 2: count the leading decode requests that sit inside the
        # first n_prompt positions; the list is ascending, so the scan can
        # stop at the first one that is already past that zone.
        n_misplaced = 0
        while (n_misplaced < n_decode
               and decode_slots[n_misplaced] < n_prompt):
            n_misplaced += 1

        if n_misplaced == 0:
            return False

        # Pair the last n_misplaced prompt requests with the first
        # n_misplaced decode requests.
        prompts_to_move = prompt_slots[:n_prompt][-n_misplaced:]
        decodes_to_move = decode_slots[:n_decode][:n_misplaced]
        assert decodes_to_move.size == prompts_to_move.size

        for prompt_entry, decode_entry in zip(prompts_to_move,
                                              decodes_to_move):
            input_batch.swap_states(prompt_entry.item(), decode_entry.item())

        return True

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """Create the ``TorchSDPAMetadata`` for the current batch.

        The first ``self.num_prompt_req`` requests are treated as prefill
        and the remaining ones (up to ``num_reqs``) as decode.
        ``common_prefix_len`` is not used.

        Returns:
            The populated ``TorchSDPAMetadata`` instance.
        """
        req_count = common_attn_metadata.num_reqs
        token_count = common_attn_metadata.num_actual_tokens
        longest_query = common_attn_metadata.max_query_len

        runner = self.runner
        table = self.block_table
        seq_lens = runner.seq_lens_np[:req_count]
        prompt_count = self.num_prompt_req

        # Longest sequence among prefill requests (0 when there are none).
        if prompt_count > 0:
            longest_prefill_seq = seq_lens[:prompt_count].max().item()
        else:
            longest_prefill_seq = 0

        # Longest sequence among decode requests (0 when there are none).
        if prompt_count < req_count:
            longest_decode_seq = seq_lens[prompt_count:req_count].max().item()
        else:
            longest_decode_seq = 0

        # Running sum of sequence lengths with a leading zero, written into
        # the NumPy view so seq_start_loc_cpu reflects it.
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens, out=self.seq_start_loc_np[1:req_count + 1])

        # Query start offsets give the token split between the two stages.
        prefill_token_count = runner.query_start_loc_np[prompt_count].item()
        total_query_tokens = runner.query_start_loc_np[req_count].item()
        decode_token_count = total_query_tokens - prefill_token_count

        slots = table.slot_mapping_cpu[:token_count].long()
        table_tensor = table.get_device_tensor()

        return TorchSDPAMetadata(
            num_prefills=prompt_count,
            num_prefill_tokens=prefill_token_count,
            num_decode_tokens=decode_token_count,
            slot_mapping=slots,
            seq_lens_tensor=runner.seq_lens_cpu[prompt_count:req_count],  # decode
            max_decode_seq_len=longest_decode_seq,  # decode
            block_tables=table_tensor[prompt_count:req_count],  # decode
            chunked_prefill=True,
            max_query_len=longest_query,
            max_kv_len=longest_prefill_seq,
            prefill_query_start_loc=runner.
            query_start_loc_cpu[:prompt_count + 1],  # prefill
            kv_start_loc=self.seq_start_loc_cpu[:prompt_count + 1],  # prefill
            prefill_block_tables=table_tensor[:prompt_count],  # prefill
            query_start_loc=runner.
            query_start_loc_cpu[:req_count + 1],  # for logits index
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
        )
|
||||
Reference in New Issue
Block a user