# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context

from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.models.dp_utils import DataParallelRuntimeParams
from vllm_mlu.v1.attention.backends.mla.flashmla import (
    FlashMLAPrefillMetadata, FlashMLAMetadata, MLACommonMetadata
)
from vllm_mlu.v1.attention.backends.utils import (
    COMMON_METADATA_STR,
    MLUCommonAttentionMetadata,
)

SEQUENCE_DIM_PARITION_THRESHOLD = 1024

|
def get_common_and_layer_metadata(
    attn_metadata: Optional[dict],
) -> Tuple[Optional[MLUCommonAttentionMetadata], Optional[AttentionMetadata]]:
    """Split the given attention metadata into its common and layer parts.

    Returns (None, None) when no metadata is given, (common, layer) for the
    dict form, and (None, attn_metadata) for a bare layer-metadata object.
    """
    if attn_metadata is None:
        return None, None

    if not isinstance(attn_metadata, dict):
        # A plain metadata object carries no separate common part.
        return None, attn_metadata

    assert COMMON_METADATA_STR in attn_metadata, (
        f"attn_metadata must contain {COMMON_METADATA_STR} key"
    )
    # All layer keys are expected to share one metadata object, so the dict
    # holds exactly two distinct values: the common one and the layer one.
    assert len({id(v) for v in attn_metadata.values()}) == 2, (
        f"attn_metadata should be a dict with two values, one for {COMMON_METADATA_STR} and "
        f"the other for layers."
    )

    common_metadata = attn_metadata[COMMON_METADATA_STR]
    layer_metadata = None
    for key, value in attn_metadata.items():
        if key != COMMON_METADATA_STR:
            layer_metadata = value
            break
    return common_metadata, layer_metadata
def should_skip_partition(layer_metadata, common_metadata) -> bool:
    """Helper function to simplify partition condition check"""
    # Common metadata must exist and describe a prefill-only batch.
    if common_metadata is None or not common_metadata.is_prefill_only:
        return True
    # Layer metadata must carry the prefill info needed for splitting.
    if layer_metadata is None or layer_metadata.prefill is None:
        return True
    query_start_loc = layer_metadata.query_start_loc
    return query_start_loc is None or query_start_loc.numel() == 0
def attn_mcc_plan(
    attn_metadata: Any,
    dp_params: DataParallelRuntimeParams,
    parts_to_split: int,
) -> Tuple[int, int]:
    """
    Returns the number of parts for batch size dimension and the number of parts for sequence length dimension.
    """
    # During a dummy run, attn_metadata is an instance of MLACommonMetadata.
    if not isinstance(attn_metadata, (dict, MLACommonMetadata, type(None))):
        raise TypeError(f"attn_metadata must be dict or MLACommonMetadata, got {type(attn_metadata)}")

    if isinstance(attn_metadata, dict):
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
    else:
        common_metadata, layer_metadata = None, attn_metadata

    if dp_params is not None:
        # Every dp rank must be running prefill; mcc with decode is unsupported.
        if not all(dp_params.dp_is_prefill):
            return 1, 1
        # Short workloads are not worth partitioning at all.
        if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD:
            return 1, 1
        max_bs = max(dp_params.batch_sizes)
        if max_bs > 1:
            # Batch dimension takes priority; never split into more parts
            # than there are requests.
            return min(parts_to_split, max_bs), 1
        return 1, parts_to_split

    # Without dp: we don't support mcc with decode yet.
    if should_skip_partition(layer_metadata, common_metadata):
        return 1, 1

    # The batch-size dimension has higher split priority than the sequence
    # dimension, and each subtask must stay non-empty.
    num_prefills = layer_metadata.query_start_loc.numel() - 1
    if num_prefills > 1:
        return min(parts_to_split, num_prefills), 1

    try:
        max_query_len = torch.diff(layer_metadata.query_start_loc).max().item()
    except RuntimeError:
        # An empty/degenerate query_start_loc makes max() fail; fall back
        # to no partitioning.
        return 1, 1

    if max_query_len < SEQUENCE_DIM_PARITION_THRESHOLD:
        return 1, 1
    return 1, min(parts_to_split, max_query_len)
def get_data_num_and_offset(total_size, parts_to_split):
    """
    Evenly split total_size into parts_to_split chunks.

    Returns (sizes, offsets), where the first ``total_size % parts_to_split``
    chunks get one extra element. Examples:
        total batch 11, parallel_num 4 -> sizes [3, 3, 3, 2], offsets [0, 3, 6, 9]
        total batch 8,  parallel_num 4 -> sizes [2, 2, 2, 2], offsets [0, 2, 4, 6]
    """
    base, extra = divmod(total_size, parts_to_split)
    # The first `extra` parts each take one additional element.
    data_num_list = [base + 1 if idx < extra else base
                     for idx in range(parts_to_split)]
    # Offsets are the running totals of all preceding part sizes.
    offset_list = [0]
    for num in data_num_list[:-1]:
        offset_list.append(offset_list[-1] + num)
    return data_num_list, offset_list
def split_dp_params(
    dp_params: DataParallelRuntimeParams,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_data_parallel_size: int,
    attn_tensor_parallel_size: int,
    prefill_dispatch_use_RS_AG: bool,
    dp_rank_: int,
) -> List[DataParallelRuntimeParams]:
    """Split dp runtime params into one DataParallelRuntimeParams per sub-task.

    Exactly one of bs_parts_to_split / seq_parts_to_split may be greater
    than 1: the former splits along the batch (request) dimension, the
    latter along the sequence (token) dimension.

    Returns a list of bs_parts_to_split * seq_parts_to_split metadata
    objects built via MLUDPMetadata.make_oot (a list of None when
    dp_params is None, i.e. dp is disabled).
    """
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."

    # No dp: each sub-task simply runs without dp metadata.
    if dp_params is None:
        return [None] * bs_parts_to_split * seq_parts_to_split

    # Nothing to split: reuse the original params as the only part.
    if bs_parts_to_split * seq_parts_to_split == 1:
        return list([dp_params])

    if bs_parts_to_split == 1:
        # ---- Sequence-dimension split. ----
        results : List[DataParallelRuntimeParams] = []
        # Per-request token counts for each sequence part; dp_params.seq_lens
        # is the flattened list of request lengths across all dp ranks.
        dp_seq_lens = []
        for seq_len in dp_params.seq_lens:
            tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
            dp_seq_lens.append(tokens)

        # Per-rank token counts for each sequence part.
        query_lens_per_dp_rank = []

        # For each dp rank, the batch size is 0 or 1.
        # bs_offset indexes the flattened seq_lens list: ranks with batch
        # size 0 contribute no entry there.
        bs_offset = 0
        for i in range(attn_data_parallel_size):
            if dp_params.batch_sizes[i] > 0:
                seq_len = dp_params.seq_lens[bs_offset]
                tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
                query_lens_per_dp_rank.append(tokens)
                bs_offset += dp_params.batch_sizes[i]
            else:
                # Idle rank: zero tokens in every part.
                query_lens_per_dp_rank.append([0] * seq_parts_to_split)

        # Build one metadata object per sequence part.
        for i in range(seq_parts_to_split):
            # Sequence splitting only applies to prefill-only batches, so
            # every rank is marked as prefill.
            dp_is_prefill = []
            for dp_rank in range(attn_data_parallel_size):
                dp_is_prefill.append(True)

            results.append(MLUDPMetadata.make_oot(
                data_parallel_rank=dp_rank_,
                data_parallel_size=attn_data_parallel_size,
                tensor_parallel_size=attn_tensor_parallel_size,
                # Tokens handled by each rank in part i.
                dp_token_nums=[query_lens_per_dp_rank[j][i] for j in range(attn_data_parallel_size)],
                dp_is_prefill=dp_is_prefill,
                prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
                # The i-th chunk of every request's length.
                seq_lens=[seq_lens[i] for seq_lens in dp_seq_lens],
                # Batch sizes are unchanged by a sequence-dimension split.
                batch_sizes=dp_params.batch_sizes,
            ))
        return results

    # ---- Batch-dimension split. ----
    bs_per_dp = dp_params.batch_sizes  # [bs_rank_0, bs_rank_1, ...]
    # NOTE(review): unused; kept for symmetry with bs_per_dp.
    seq_lens_per_dp = dp_params.seq_lens  # [seq_len_bs_0, seq_len_bs_1,...]

    # [[bs_rank_0_part_0, bs_rank_0_part_1,...], [bs_rank_1_part_0, bs_rank_1_part_1,...], ...]
    split_bs_per_dp = []
    # [[
    #    [bs0_part_0_rank_0, bs1_part_0_rank_0, ...],
    #    [bs0_part_1_rank_0, bs1_part_1_rank_0, ...],
    #    ...
    #  ],
    #  [
    #    [bs0_part_0_rank_1, bs1_part_0_rank_1, ...],
    #    [bs0_part_1_rank_1, bs1_part_1_rank_1, ...],
    #    ...
    #  ],
    # ]
    split_query_lens_per_dp = []
    for dp_rank in range(attn_data_parallel_size):
        # Split this rank's batch into bs_parts_to_split chunks.
        _bs, _offset = get_data_num_and_offset(bs_per_dp[dp_rank], bs_parts_to_split)
        split_bs_per_dp.append(_bs)
        split_query_lens_per_dp.append([])
        for i in range(bs_parts_to_split):
            # Index into the flattened seq_lens: skip all requests of the
            # preceding ranks, then apply this part's offset.
            start = sum(bs_per_dp[:dp_rank]) + _offset[i]
            end = start + _bs[i]
            split_query_lens_per_dp[-1].append(dp_params.seq_lens[start:end])

    results : List[DataParallelRuntimeParams] = []
    for i in range(bs_parts_to_split):
        # Total tokens each rank contributes to part i.
        dp_query_lens = [sum(split_query_lens_per_dp[dp_rank][i]) for dp_rank in range(attn_data_parallel_size)]
        # Flattened per-request lengths for part i, ordered by rank.
        seq_lens = []
        for dp_rank in range(attn_data_parallel_size):
            seq_lens += split_query_lens_per_dp[dp_rank][i]
        # Per-rank batch sizes for part i.
        batch_sizes = []
        for dp_rank in range(attn_data_parallel_size):
            batch_sizes.append(split_bs_per_dp[dp_rank][i])

        # Batch splitting only applies to prefill-only batches.
        dp_is_prefill = []
        for dp_rank in range(attn_data_parallel_size):
            dp_is_prefill.append(True)
        results.append(MLUDPMetadata.make_oot(
            data_parallel_rank=dp_rank_,
            data_parallel_size=attn_data_parallel_size,
            tensor_parallel_size=attn_tensor_parallel_size,
            dp_token_nums=dp_query_lens,
            dp_is_prefill=dp_is_prefill,
            prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
            seq_lens=seq_lens,
            batch_sizes=batch_sizes,
        ))

    return results
def split_input(
    input: torch.Tensor,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_metadata_list: List[AttentionMetadata],
) -> List[torch.Tensor]:
    """Slice the flattened token input into one chunk per sub-task.

    The chunk sizes come from each sub-metadata's num_actual_tokens; a
    dummy-run metadata (no layer part) assigns the whole input to the
    first chunk.
    """
    assert seq_parts_to_split == 1 or bs_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."

    total_parts = bs_parts_to_split * seq_parts_to_split
    if input is None:
        return [None] * total_parts
    if total_parts == 1:
        return [input]

    # Token count owned by each sub-task.
    token_num_list = []
    for idx, metadata in enumerate(attn_metadata_list):
        _, layer_metadata = get_common_and_layer_metadata(metadata)
        if layer_metadata is not None:
            token_num_list.append(layer_metadata.num_actual_tokens)
        elif idx == 0:
            # A special case for dummy run: all tokens go to the first part.
            token_num_list.append(input.shape[0])
        else:
            token_num_list.append(0)

    # Walk the input with a running cursor instead of recomputing prefix sums.
    results = []
    cursor = 0
    for part in range(total_parts):
        chunk = token_num_list[part]
        results.append(input[cursor:cursor + chunk])
        cursor += chunk
    return results
def split_positions(
    positions: torch.Tensor,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_metadata: AttentionMetadata,
) -> List[torch.Tensor]:
    """Split the position ids tensor, one slice per sub-task.

    Batch-dimension splitting reuses the full tensor for every part;
    sequence-dimension splitting slices it by per-part token counts.
    """
    if seq_parts_to_split == 1:
        return [positions] * bs_parts_to_split

    _, layer_metadata = get_common_and_layer_metadata(attn_metadata)
    total_tokens = 0 if layer_metadata is None else layer_metadata.num_actual_tokens

    tokens, offsets = get_data_num_and_offset(total_tokens, seq_parts_to_split)
    return [
        positions[offset: offset + num]
        for offset, num in zip(offsets, tokens)
    ]
def split_attn_metadata(
    attn_metadata: dict,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
) -> List[Any]:
    """Split attn_metadata (a dict of common and layer metadata) into one
    sub-metadata dict per partition.

    Exactly one of bs_parts_to_split / seq_parts_to_split may be greater
    than 1: the former splits along the batch (request) dimension, the
    latter along the sequence (token) dimension of a single prefill
    request. Each returned dict has the same keys as attn_metadata, with
    the common/layer values replaced by the per-partition sub-metadata.
    """
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."
    # Nothing to split: reuse the original metadata as the only part.
    if bs_parts_to_split == 1 and seq_parts_to_split == 1:
        return list([attn_metadata])

    if attn_metadata is None:
        return [None] * bs_parts_to_split * seq_parts_to_split

    if seq_parts_to_split > 1:
        # ---- Sequence-dimension partition (a single prefill request). ----
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
        if common_metadata is None or not hasattr(common_metadata, 'num_actual_tokens'):
            raise ValueError("common_metadata is invalid or missing num_actual_tokens")
        num_prefill_tokens = common_metadata.num_actual_tokens
        # Per-part token counts and their start offsets within the request.
        tokens, offsets = get_data_num_and_offset(num_prefill_tokens, seq_parts_to_split)
        device = common_metadata.seq_lens.device
        sub_common_metadata, sub_layer_metadata = [], []
        for i in range(seq_parts_to_split):
            # query_start_loc tensor, which indices positions in input.
            # NOTE(review): only indices 0 and 1 are written — assumes a
            # single request (length-2 query_start_loc); confirm.
            query_start_loc_tensor = torch.empty_like(common_metadata.query_start_loc)
            query_start_loc_tensor[0] = 0
            query_start_loc_tensor[1] = tokens[i]
            # seq_lens tensor: part i attends to the causal prefix ending
            # at its last token, i.e. offsets[i] + tokens[i] kv tokens.
            seq_lens_tensor = torch.tensor(
                [offsets[i] + tokens[i]],
                dtype=common_metadata.seq_lens.dtype,
                device=device
            )
            # seq_start_loc tensor, which indicates positions in the sequence(kv cache).
            seq_start_loc_tensor = torch.empty_like(common_metadata.seq_start_loc)
            seq_start_loc_tensor[0] = offsets[i]
            seq_start_loc_tensor[1] = offsets[i] + tokens[i]
            # max_query_len scalar: one request per part, so it equals the
            # part's token count.
            max_query_len = tokens[i]
            # num_actual_tokens scalar
            num_actual_tokens = tokens[i]
            # num_input_tokens scalar
            num_input_tokens = num_actual_tokens
            # infer_mode is inherited unchanged from the parent metadata.
            infer_mode = common_metadata.infer_mode
            # Build the per-part common metadata.
            sub_common_metadata.append(MLUCommonAttentionMetadata(
                query_start_loc=query_start_loc_tensor,
                query_start_loc_cpu=common_metadata.query_start_loc_cpu,  # FIXME: split when used
                seq_lens=seq_lens_tensor,
                seq_lens_cpu=common_metadata.seq_lens_cpu,  # FIXME: split when used
                num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu,  # FIXME: split when used
                num_reqs=common_metadata.num_reqs,  # FIXME: split when used
                num_actual_tokens=num_actual_tokens,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                block_table_tensor=common_metadata.block_table_tensor,  # FIXME: split when used
                slot_mapping=common_metadata.slot_mapping,  # FIXME: split when used
                seq_start_loc=seq_start_loc_tensor,
                num_input_tokens=num_input_tokens,
                infer_mode=infer_mode,
                num_prefill_query_tokens=tokens[i],
                num_prefill_kv_tokens=offsets[i] + tokens[i],
            ))
            # slot_mapping tensor: this part's slice of the request's slots.
            slot_mapping = layer_metadata.slot_mapping[offsets[i]:offsets[i] + tokens[i]]
            # Sequence splitting is only valid for exactly one prefill
            # request and no decode work; validate before building.
            REQUIRED_NUM_DECODES = 0
            REQUIRED_NUM_DECODE_TOKENS = 0
            REQUIRED_NUM_PREFILLS = 1
            if not hasattr(layer_metadata, 'num_prefills') or \
                    layer_metadata.num_prefills is None:
                raise ValueError("layer_metadata.num_prefills is required")

            assert layer_metadata.num_decodes == REQUIRED_NUM_DECODES and \
                layer_metadata.num_decode_tokens == REQUIRED_NUM_DECODE_TOKENS and \
                layer_metadata.num_prefills == REQUIRED_NUM_PREFILLS, (
                f"num_decodes, num_decode_tokens, num_prefills must be {REQUIRED_NUM_DECODES}, {REQUIRED_NUM_DECODE_TOKENS}, "
                f"{REQUIRED_NUM_PREFILLS}, but got {layer_metadata.num_decodes}, {layer_metadata.num_decode_tokens}, "
                f"{layer_metadata.num_prefills}."
            )
            assert layer_metadata.prefill.chunked_context is None, (
                f"chunked_context is only available for prefill with chunked context, "
                f"and it is not supported when enabling mcc."
            )
            prefill_metadata = FlashMLAPrefillMetadata(
                block_table=layer_metadata.prefill.block_table,
                query_start_loc=query_start_loc_tensor,
                max_query_len=max_query_len,
                chunked_context=None,
                num_prefills=layer_metadata.prefill.num_prefills,
                max_seq_len=layer_metadata.prefill.max_seq_len,
            )
            # Note: for sequence dimension partition, we provide cu_seqlens_kv filed to
            # indicates key/value size for flash attention operator.
            prefill_metadata.cu_seqlens_kv = torch.empty_like(prefill_metadata.query_start_loc)
            prefill_metadata.cu_seqlens_kv[0] = 0
            prefill_metadata.cu_seqlens_kv[1] = offsets[i] + tokens[i]

            # Build the per-part layer metadata; decode fields pass through
            # unchanged (asserted to be empty above).
            sub_layer_metadata.append(FlashMLAMetadata(
                num_reqs=layer_metadata.num_reqs,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                query_start_loc=query_start_loc_tensor,
                slot_mapping=slot_mapping,
                num_decodes=layer_metadata.num_decodes,
                num_decode_tokens=layer_metadata.num_decode_tokens,
                num_prefills=layer_metadata.num_prefills,
                num_prefill_tokens=tokens[i],
                head_dim=layer_metadata.head_dim,
                decode=layer_metadata.decode,
                prefill=prefill_metadata,
            ))

        # Reassemble one dict per part with the original keys: the common
        # key maps to the part's common metadata, every layer key maps to
        # the part's (shared) layer metadata.
        sub_attn_metadata_list = []
        for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
            sub_attn_metadata_dict = {}
            for key, value in attn_metadata.items():
                if key == COMMON_METADATA_STR:
                    sub_attn_metadata_dict[key] = common_meta
                else:
                    sub_attn_metadata_dict[key] = layer_meta
            sub_attn_metadata_list.append(sub_attn_metadata_dict)
        return sub_attn_metadata_list

    elif bs_parts_to_split > 1:
        # ---- Batch-dimension partition (multiple prefill requests). ----
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
        if not hasattr(layer_metadata, 'num_prefills') or layer_metadata.num_prefills is None:
            raise ValueError("layer_metadata.num_prefills is required")
        total_batch = layer_metadata.num_prefills
        # Per-part request counts and their start offsets in the batch.
        batch_sizes, offsets = get_data_num_and_offset(total_batch, bs_parts_to_split)
        sub_common_metadata, sub_layer_metadata = [], []
        for i in range(bs_parts_to_split):
            # query_start_loc tensor: this part's slice, rebased to start
            # at zero.
            start, end = offsets[i], offsets[i] + batch_sizes[i]
            query_start_loc_tensor = common_metadata.query_start_loc[start:end+1].clone()
            if i > 0:
                # Part 0 already starts at 0, so rebasing is skipped there.
                query_start_loc_tensor -= common_metadata.query_start_loc[start]
            # block_table: width-0 table per request.
            # NOTE(review): presumably the prefill path ignores the block
            # table contents here — confirm.
            block_tables = torch.empty(
                (batch_sizes[i], 0),
                dtype=layer_metadata.prefill.block_table.dtype,
                device=layer_metadata.prefill.block_table.device,
            )
            # seq_lens tensor for this part's requests.
            seq_lens_tensor = common_metadata.seq_lens[start:end].clone()
            # seq_start_loc tensor: identical to query_start_loc because
            # these are pure prefills (query covers the whole sequence).
            seq_start_loc_tensor = query_start_loc_tensor
            # max_query_len scalar (guards against an empty part).
            max_query_len = seq_lens_tensor.max().item() if seq_lens_tensor.numel() > 0 else 0
            # num_actual_tokens scalar: total tokens in this part.
            num_actual_tokens = seq_start_loc_tensor[-1].item()
            # num_input_tokens scalar
            num_input_tokens = num_actual_tokens
            # infer_mode is inherited unchanged from the parent metadata.
            infer_mode = common_metadata.infer_mode
            # slot_mapping tensor: the running token offset of all parts
            # built so far gives this part's start in the flat slot map.
            slot_mapping_start = 0
            for data in sub_common_metadata:
                slot_mapping_start += data.num_actual_tokens
            slot_mapping_tensor = layer_metadata.slot_mapping[
                slot_mapping_start:slot_mapping_start + num_actual_tokens
            ]
            # Build the per-part common metadata.
            sub_common_metadata.append(MLUCommonAttentionMetadata(
                query_start_loc=query_start_loc_tensor,
                query_start_loc_cpu=common_metadata.query_start_loc_cpu,  # FIXME: split when used
                seq_lens=seq_lens_tensor,
                seq_lens_cpu=common_metadata.seq_lens_cpu,  # FIXME: split when used
                num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu,  # FIXME: split when used
                num_reqs=common_metadata.num_reqs,  # FIXME: split when used
                block_table_tensor=common_metadata.block_table_tensor,  # FIXME: split when used
                slot_mapping=common_metadata.slot_mapping,  # FIXME: split when used
                seq_start_loc=seq_start_loc_tensor,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                num_input_tokens=num_input_tokens,
                infer_mode=infer_mode,
                num_prefill_query_tokens=num_actual_tokens,
                num_prefill_kv_tokens=num_actual_tokens,
            ))
            # Build the per-part layer metadata.
            prefill_metadata = FlashMLAPrefillMetadata(
                block_table=block_tables,
                query_start_loc=query_start_loc_tensor,
                max_query_len=max_query_len,
                chunked_context=None,
                num_prefills=batch_sizes[i],
                max_seq_len=max_query_len,
            )
            sub_layer_metadata.append(FlashMLAMetadata(
                num_reqs=batch_sizes[i],
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                query_start_loc=query_start_loc_tensor,
                slot_mapping=slot_mapping_tensor,
                num_decodes=layer_metadata.num_decodes,  # useless field
                num_decode_tokens=0,  # useless field
                num_prefills=batch_sizes[i],
                num_prefill_tokens=num_actual_tokens,
                head_dim=layer_metadata.head_dim,
                decode=layer_metadata.decode,
                prefill=prefill_metadata,
            ))

        # Reassemble one dict per part (same logic as the seq branch above).
        sub_attn_metadata_list = []
        for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
            sub_attn_metadata_dict = {}
            for key, value in attn_metadata.items():
                if key == COMMON_METADATA_STR:
                    sub_attn_metadata_dict[key] = common_meta
                else:
                    sub_attn_metadata_dict[key] = layer_meta
            sub_attn_metadata_list.append(sub_attn_metadata_dict)
        return sub_attn_metadata_list
def execute_with_updated_forward_context(
    vllm_config: VllmConfig,
    attn_metadata: AttentionMetadata,
    func: Callable,
    kwargs: Dict[str, Any],
):
    """Run ``func(**kwargs)`` inside a forward context built from the
    given attention metadata and vllm config, returning its result."""
    context = set_forward_context(attn_metadata, vllm_config)
    with context:
        return func(**kwargs)