# NOTE(review): the lines below were a web file-viewer header captured along
# with the source ("Files", timestamp, line/size counts); commented out so the
# module parses as Python.
# Files | 2026-04-24 09:58:03 +08:00 | 508 lines | 22 KiB | Python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.models.dp_utils import DataParallelRuntimeParams
from vllm_mlu.v1.attention.backends.mla.flashmla import (
FlashMLAPrefillMetadata, FlashMLAMetadata, MLACommonMetadata
)
from vllm_mlu.v1.attention.backends.utils import (
COMMON_METADATA_STR,
MLUCommonAttentionMetadata,
)
# Minimum number of tokens in a request below which splitting along the
# sequence dimension is not worthwhile (see attn_mcc_plan).
# NOTE(review): "PARITION" is a typo for "PARTITION"; kept as-is since other
# modules may reference this name.
SEQUENCE_DIM_PARITION_THRESHOLD = 1024
def get_common_and_layer_metadata(
    attn_metadata: Optional[dict],
) -> Tuple[Optional[MLUCommonAttentionMetadata], Optional[AttentionMetadata]]:
    """Split attention metadata into its (common, layer) parts.

    A dict input must contain the shared common metadata under
    ``COMMON_METADATA_STR`` plus exactly one distinct per-layer metadata
    object referenced by every remaining key.  A non-dict input is treated
    as pure layer metadata with no common part.
    """
    if attn_metadata is None:
        return None, None
    if not isinstance(attn_metadata, dict):
        return None, attn_metadata
    assert COMMON_METADATA_STR in attn_metadata, (
        f"attn_metadata must contain {COMMON_METADATA_STR} key"
    )
    distinct_objects = {id(value) for value in attn_metadata.values()}
    assert len(distinct_objects) == 2, (
        f"attn_metadata should be a dict with two values, one for {COMMON_METADATA_STR} and "
        f"the other for layers."
    )
    common_metadata = attn_metadata[COMMON_METADATA_STR]
    layer_metadata = None
    for key, value in attn_metadata.items():
        if key != COMMON_METADATA_STR:
            layer_metadata = value
            break
    return common_metadata, layer_metadata
def should_skip_partition(layer_metadata, common_metadata) -> bool:
    """Return True when the metadata cannot (or should not) be partitioned.

    Partitioning is skipped when the layer metadata lacks prefill info or a
    non-empty ``query_start_loc``, or when the common metadata is missing or
    not prefill-only.
    """
    if layer_metadata is None or layer_metadata.prefill is None:
        return True
    start_loc = layer_metadata.query_start_loc
    if start_loc is None or start_loc.numel() == 0:
        return True
    return common_metadata is None or not common_metadata.is_prefill_only
def attn_mcc_plan(
    attn_metadata: Any,
    dp_params: DataParallelRuntimeParams,
    parts_to_split: int,
) -> Tuple[int, int]:
    """Plan how to partition work for mcc.

    Returns ``(bs_parts, seq_parts)``: the number of partitions along the
    batch-size dimension and along the sequence-length dimension.  At most one
    of the two is greater than 1; ``(1, 1)`` means no partitioning.
    """
    # During a dummy run, attn_metadata is a bare MLACommonMetadata instance.
    if not isinstance(attn_metadata, (dict, MLACommonMetadata, type(None))):
        raise TypeError(f"attn_metadata must be dict or MLACommonMetadata, got {type(attn_metadata)}")
    if isinstance(attn_metadata, dict):
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
    else:
        common_metadata, layer_metadata = None, attn_metadata

    if dp_params is not None:
        # mcc with decode is not supported: every DP rank must be prefill.
        if not all(dp_params.dp_is_prefill):
            return 1, 1
        # Too few tokens per rank: splitting would not pay off.
        if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD:
            return 1, 1
        largest_bs = max(dp_params.batch_sizes)
        if largest_bs > 1:
            # Cap by the batch size so no split part is empty.
            return min(parts_to_split, largest_bs), 1
        return 1, parts_to_split

    # Non-DP path.  We don't support mcc with decode yet.
    if should_skip_partition(layer_metadata, common_metadata):
        return 1, 1
    # Splitting the batch-size dimension has priority over the sequence
    # dimension, and each subtask must be non-empty.
    num_prefills = layer_metadata.query_start_loc.numel() - 1
    if num_prefills > 1:
        return min(parts_to_split, num_prefills), 1
    try:
        max_query_len = torch.diff(layer_metadata.query_start_loc).max().item()
    except RuntimeError:
        return 1, 1
    if max_query_len < SEQUENCE_DIM_PARITION_THRESHOLD:
        return 1, 1
    return 1, min(parts_to_split, max_query_len)
def get_data_num_and_offset(total_size, parts_to_split):
    """Evenly partition ``total_size`` items into ``parts_to_split`` chunks.

    Returns ``(sizes, offsets)`` where the first ``total_size % parts_to_split``
    chunks receive one extra item.
    For example, total batch 11, parallel_num 4 -> sizes [3, 3, 3, 2],
    offsets [0, 3, 6, 9]; total batch 8, parallel_num 4 -> sizes [2, 2, 2, 2],
    offsets [0, 2, 4, 6].
    """
    base, extra = divmod(total_size, parts_to_split)
    # The first `extra` chunks carry the remainder, one item each.
    sizes = [base + 1 if part < extra else base for part in range(parts_to_split)]
    offsets = []
    running = 0
    for size in sizes:
        offsets.append(running)
        running += size
    return sizes, offsets
def split_dp_params(
    dp_params: DataParallelRuntimeParams,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_data_parallel_size: int,
    attn_tensor_parallel_size: int,
    prefill_dispatch_use_RS_AG: bool,
    dp_rank_: int,
) -> List[DataParallelRuntimeParams]:
    """Split data-parallel runtime params into one object per sub-task.

    Exactly one of ``bs_parts_to_split`` / ``seq_parts_to_split`` may exceed 1.
    Returns ``bs_parts_to_split * seq_parts_to_split`` new ``MLUDPMetadata``
    objects describing the corresponding slice of the original work (or
    ``[dp_params]`` / a list of ``None`` in the trivial cases).
    """
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."
    if dp_params is None:
        return [None] * bs_parts_to_split * seq_parts_to_split
    if bs_parts_to_split * seq_parts_to_split == 1:
        # Nothing to split: hand back the original params unchanged.
        return list([dp_params])
    if bs_parts_to_split == 1:
        # ---- Sequence-dimension split ----
        results : List[DataParallelRuntimeParams] = []
        # dp_seq_lens[r][i] is the token count of part i of request r.
        dp_seq_lens = []
        for seq_len in dp_params.seq_lens:
            tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
            dp_seq_lens.append(tokens)
        # query_lens_per_dp_rank[rank][i] is the token count rank gets in part i.
        query_lens_per_dp_rank = []
        # For each dp rank, the batch size is 0 or 1.
        bs_offset = 0
        for i in range(attn_data_parallel_size):
            if dp_params.batch_sizes[i] > 0:
                # seq_lens is flattened across ranks; bs_offset indexes this
                # rank's (single) request.
                seq_len = dp_params.seq_lens[bs_offset]
                tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
                query_lens_per_dp_rank.append(tokens)
                bs_offset += dp_params.batch_sizes[i]
            else:
                # Rank has no request: every part gets zero tokens.
                query_lens_per_dp_rank.append([0] * seq_parts_to_split)
        for i in range(seq_parts_to_split):
            # Sequence splitting is only planned for prefill-only batches
            # (see attn_mcc_plan), so every rank is marked as prefill.
            dp_is_prefill = []
            for dp_rank in range(attn_data_parallel_size):
                dp_is_prefill.append(True)
            results.append(MLUDPMetadata.make_oot(
                data_parallel_rank=dp_rank_,
                data_parallel_size=attn_data_parallel_size,
                tensor_parallel_size=attn_tensor_parallel_size,
                dp_token_nums=[query_lens_per_dp_rank[j][i] for j in range(attn_data_parallel_size)],
                dp_is_prefill=dp_is_prefill,
                prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
                seq_lens=[seq_lens[i] for seq_lens in dp_seq_lens],
                # Batch sizes are unchanged by a sequence split.
                batch_sizes=dp_params.batch_sizes,
            ))
        return results
    # ---- Batch-dimension split ----
    bs_per_dp = dp_params.batch_sizes  # [bs_rank_0, bs_rank_1, ...]
    # NOTE(review): seq_lens_per_dp appears unused below; dp_params.seq_lens
    # is accessed directly instead.
    seq_lens_per_dp = dp_params.seq_lens  # [seq_len_bs_0, seq_len_bs_1,...]
    # [[bs_rank_0_part_0, bs_rank_0_part_1,...], [bs_rank_1_part_0, bs_rank_1_part_1,...], ...]
    split_bs_per_dp = []
    # [[
    #    [bs0_part_0_rank_0, bs1_part_0_rank_0, ...],
    #    [bs0_part_1_rank_0, bs1_part_1_rank_0, ...],
    #    ...
    #  ],
    #  [
    #    [bs0_part_0_rank_1, bs1_part_0_rank_1, ...],
    #    [bs0_part_1_rank_1, bs1_part_1_rank_1, ...],
    #    ...
    #  ],
    # ]
    split_query_lens_per_dp = []
    for dp_rank in range(attn_data_parallel_size):
        _bs, _offset = get_data_num_and_offset(bs_per_dp[dp_rank], bs_parts_to_split)
        split_bs_per_dp.append(_bs)
        split_query_lens_per_dp.append([])
        for i in range(bs_parts_to_split):
            # seq_lens is flattened across ranks: skip the requests of all
            # earlier ranks, then take this part's slice.
            start = sum(bs_per_dp[:dp_rank]) + _offset[i]
            end = start + _bs[i]
            split_query_lens_per_dp[-1].append(dp_params.seq_lens[start:end])
    results : List[DataParallelRuntimeParams] = []
    for i in range(bs_parts_to_split):
        # Total token count per rank for part i.
        dp_query_lens = [sum(split_query_lens_per_dp[dp_rank][i]) for dp_rank in range(attn_data_parallel_size)]
        seq_lens = []
        for dp_rank in range(attn_data_parallel_size):
            seq_lens += split_query_lens_per_dp[dp_rank][i]
        batch_sizes = []
        for dp_rank in range(attn_data_parallel_size):
            batch_sizes.append(split_bs_per_dp[dp_rank][i])
        # Batch splitting is only planned for prefill-only batches
        # (see attn_mcc_plan), so every rank is marked as prefill.
        dp_is_prefill = []
        for dp_rank in range(attn_data_parallel_size):
            dp_is_prefill.append(True)
        results.append(MLUDPMetadata.make_oot(
            data_parallel_rank=dp_rank_,
            data_parallel_size=attn_data_parallel_size,
            tensor_parallel_size=attn_tensor_parallel_size,
            dp_token_nums=dp_query_lens,
            dp_is_prefill=dp_is_prefill,
            prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
            seq_lens=seq_lens,
            batch_sizes=batch_sizes,
        ))
    return results
def split_input(
    input: torch.Tensor,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_metadata_list: List[AttentionMetadata],
) -> List[torch.Tensor]:
    """Slice ``input`` along the token dimension, one piece per sub-metadata.

    Each slice length is taken from the corresponding layer metadata's
    ``num_actual_tokens``; a leading metadata entry without layer metadata
    (dummy run) claims all tokens for the first part.
    """
    assert seq_parts_to_split == 1 or bs_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."
    total_parts = bs_parts_to_split * seq_parts_to_split
    if input is None:
        return [None] * total_parts
    if total_parts == 1:
        return [input]
    token_counts = []
    for idx, metadata in enumerate(attn_metadata_list):
        _, layer_metadata = get_common_and_layer_metadata(metadata)
        if layer_metadata is not None:
            token_counts.append(layer_metadata.num_actual_tokens)
        elif idx == 0:
            # A special case for dummy run: all tokens go to the first part.
            token_counts.append(input.shape[0])
        else:
            token_counts.append(0)
    pieces = []
    cursor = 0
    for part in range(total_parts):
        next_cursor = cursor + token_counts[part]
        pieces.append(input[cursor:next_cursor])
        cursor = next_cursor
    return pieces
def split_positions(
    positions: torch.Tensor,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_metadata: AttentionMetadata,
) -> List[torch.Tensor]:
    """Produce one positions tensor per sub-task.

    Batch-dimension parts all share the original ``positions`` tensor;
    sequence-dimension parts get contiguous slices sized by an even split of
    the actual token count.
    """
    if seq_parts_to_split == 1:
        return [positions] * bs_parts_to_split
    _, layer_metadata = get_common_and_layer_metadata(attn_metadata)
    total_tokens = 0 if layer_metadata is None else layer_metadata.num_actual_tokens
    part_sizes, part_offsets = get_data_num_and_offset(total_tokens, seq_parts_to_split)
    return [
        positions[offset: offset + size]
        for size, offset in zip(part_sizes, part_offsets)
    ]
def split_attn_metadata(
    attn_metadata: dict,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
) -> List[Any]:
    """Split attention metadata into per-sub-task copies.

    ``attn_metadata`` is a dict containing common and layer metadata (see
    ``get_common_and_layer_metadata``).  Exactly one of ``bs_parts_to_split``
    / ``seq_parts_to_split`` may exceed 1; the result is one metadata dict per
    part, sliced along the batch or sequence dimension respectively.
    """
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."
    if bs_parts_to_split == 1 and seq_parts_to_split == 1:
        return list([attn_metadata])
    if attn_metadata is None:
        return [None] * bs_parts_to_split * seq_parts_to_split
    if seq_parts_to_split > 1:
        # ---- Sequence-dimension split ----
        # This branch requires exactly one prefill request (asserted below),
        # so the per-part offsets double as positions inside that sequence.
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
        if common_metadata is None or not hasattr(common_metadata, 'num_actual_tokens'):
            raise ValueError("common_metadata is invalid or missing num_actual_tokens")
        num_prefill_tokens = common_metadata.num_actual_tokens
        tokens, offsets = get_data_num_and_offset(num_prefill_tokens, seq_parts_to_split)
        device = common_metadata.seq_lens.device
        sub_common_metadata, sub_layer_metadata = [], []
        for i in range(seq_parts_to_split):
            # query_start_loc tensor, which indexes positions in the input.
            # NOTE(review): only elements [0] and [1] are written; assumes the
            # tensor has length 2 (single request) — consistent with the
            # num_prefills == 1 assertion below.
            query_start_loc_tensor = torch.empty_like(common_metadata.query_start_loc)
            query_start_loc_tensor[0] = 0
            query_start_loc_tensor[1] = tokens[i]
            # seq_lens tensor: cumulative kv length visible to part i.
            seq_lens_tensor = torch.tensor(
                [offsets[i] + tokens[i]],
                dtype=common_metadata.seq_lens.dtype,
                device=device
            )
            # seq_start_loc tensor, which indicates positions in the sequence (kv cache).
            seq_start_loc_tensor = torch.empty_like(common_metadata.seq_start_loc)
            seq_start_loc_tensor[0] = offsets[i]
            seq_start_loc_tensor[1] = offsets[i] + tokens[i]
            # max_query_len scalar
            max_query_len = tokens[i]
            # num_actual_tokens scalar
            num_actual_tokens = tokens[i]
            # num_input_tokens scalar
            num_input_tokens = num_actual_tokens
            # infer_mode
            infer_mode = common_metadata.infer_mode
            # update common metadata
            sub_common_metadata.append(MLUCommonAttentionMetadata(
                query_start_loc=query_start_loc_tensor,
                query_start_loc_cpu=common_metadata.query_start_loc_cpu, # FIXME: split when used
                seq_lens=seq_lens_tensor,
                seq_lens_cpu=common_metadata.seq_lens_cpu, # FIXME: split when used
                num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu, # FIXME: split when used
                num_reqs=common_metadata.num_reqs, # FIXME: split when used
                num_actual_tokens=num_actual_tokens,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                block_table_tensor=common_metadata.block_table_tensor, # FIXME: split when used
                slot_mapping=common_metadata.slot_mapping, # FIXME: split when used
                seq_start_loc=seq_start_loc_tensor,
                num_input_tokens=num_input_tokens,
                infer_mode=infer_mode,
                num_prefill_query_tokens=tokens[i],
                num_prefill_kv_tokens=offsets[i] + tokens[i],
            ))
            # slot_mapping tensor: this part's contiguous token slice.
            slot_mapping = layer_metadata.slot_mapping[offsets[i]:offsets[i] + tokens[i]]
            # update layer metadata
            REQUIRED_NUM_DECODES = 0
            REQUIRED_NUM_DECODE_TOKENS = 0
            REQUIRED_NUM_PREFILLS = 1
            if not hasattr(layer_metadata, 'num_prefills') or \
                    layer_metadata.num_prefills is None:
                raise ValueError("layer_metadata.num_prefills is required")
            # Sequence splitting only supports a single prefill request with
            # no decodes.
            assert layer_metadata.num_decodes == REQUIRED_NUM_DECODES and \
                layer_metadata.num_decode_tokens == REQUIRED_NUM_DECODE_TOKENS and \
                layer_metadata.num_prefills == REQUIRED_NUM_PREFILLS, (
                f"num_decodes, num_decode_tokens, num_prefills must be {REQUIRED_NUM_DECODES}, {REQUIRED_NUM_DECODE_TOKENS}, "
                f"{REQUIRED_NUM_PREFILLS}, but got {layer_metadata.num_decodes}, {layer_metadata.num_decode_tokens}, "
                f"{layer_metadata.num_prefills}."
            )
            assert layer_metadata.prefill.chunked_context is None, (
                f"chunked_context is only available for prefill with chunked context, "
                f"and it is not supported when enabling mcc."
            )
            prefill_metadata = FlashMLAPrefillMetadata(
                block_table=layer_metadata.prefill.block_table,
                query_start_loc=query_start_loc_tensor,
                max_query_len=max_query_len,
                chunked_context=None,
                num_prefills=layer_metadata.prefill.num_prefills,
                max_seq_len=layer_metadata.prefill.max_seq_len,
            )
            # Note: for sequence dimension partition, we provide a cu_seqlens_kv field to
            # indicate the key/value size for the flash attention operator.
            prefill_metadata.cu_seqlens_kv = torch.empty_like(prefill_metadata.query_start_loc)
            prefill_metadata.cu_seqlens_kv[0] = 0
            prefill_metadata.cu_seqlens_kv[1] = offsets[i] + tokens[i]
            sub_layer_metadata.append(FlashMLAMetadata(
                num_reqs=layer_metadata.num_reqs,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                query_start_loc=query_start_loc_tensor,
                slot_mapping=slot_mapping,
                num_decodes=layer_metadata.num_decodes,
                num_decode_tokens=layer_metadata.num_decode_tokens,
                num_prefills=layer_metadata.num_prefills,
                num_prefill_tokens=tokens[i],
                head_dim=layer_metadata.head_dim,
                decode=layer_metadata.decode,
                prefill=prefill_metadata,
            ))
        # Reassemble per-part dicts with the same keys as the input: the
        # common key maps to the split common metadata, every other key to
        # the (shared) split layer metadata.
        sub_attn_metadata_list = []
        for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
            sub_attn_metadata_dict = {}
            for key, value in attn_metadata.items():
                if key == COMMON_METADATA_STR:
                    sub_attn_metadata_dict[key] = common_meta
                else:
                    sub_attn_metadata_dict[key] = layer_meta
            sub_attn_metadata_list.append(sub_attn_metadata_dict)
        return sub_attn_metadata_list
    elif bs_parts_to_split > 1:
        # ---- Batch-dimension split ----
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
        if not hasattr(layer_metadata, 'num_prefills') or layer_metadata.num_prefills is None:
            raise ValueError("layer_metadata.num_prefills is required")
        total_batch = layer_metadata.num_prefills
        batch_sizes, offsets = get_data_num_and_offset(total_batch, bs_parts_to_split)
        sub_common_metadata, sub_layer_metadata = [], []
        for i in range(bs_parts_to_split):
            # query_start_loc tensor
            start, end = offsets[i], offsets[i] + batch_sizes[i]
            query_start_loc_tensor = common_metadata.query_start_loc[start:end+1].clone()
            if i > 0:
                # Rebase so each part's query_start_loc begins at 0.
                query_start_loc_tensor -= common_metadata.query_start_loc[start]
            # block_table: intentionally empty (0 columns) for this part.
            # NOTE(review): presumably unused downstream for batch splits —
            # confirm against the attention backend.
            block_tables = torch.empty(
                (batch_sizes[i], 0),
                dtype=layer_metadata.prefill.block_table.dtype,
                device=layer_metadata.prefill.block_table.device,
            )
            # seq_lens tensor
            seq_lens_tensor = common_metadata.seq_lens[start:end].clone()
            # seq_start_loc tensor
            seq_start_loc_tensor = query_start_loc_tensor
            # max_query_len scalar
            max_query_len = seq_lens_tensor.max().item() if seq_lens_tensor.numel() > 0 else 0
            # num_actual_tokens scalar
            num_actual_tokens = seq_start_loc_tensor[-1].item()
            # num_input_tokens scalar
            num_input_tokens = num_actual_tokens
            # infer_mode
            infer_mode = common_metadata.infer_mode
            # slot_mapping tensor: skip the tokens consumed by earlier parts.
            slot_mapping_start = 0
            for data in sub_common_metadata:
                slot_mapping_start += data.num_actual_tokens
            slot_mapping_tensor = layer_metadata.slot_mapping[
                slot_mapping_start:slot_mapping_start + num_actual_tokens
            ]
            # update common metadata
            sub_common_metadata.append(MLUCommonAttentionMetadata(
                query_start_loc=query_start_loc_tensor,
                query_start_loc_cpu=common_metadata.query_start_loc_cpu, # FIXME: split when used
                seq_lens=seq_lens_tensor,
                seq_lens_cpu=common_metadata.seq_lens_cpu, # FIXME: split when used
                num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu, # FIXME: split when used
                num_reqs=common_metadata.num_reqs, # FIXME: split when used
                block_table_tensor=common_metadata.block_table_tensor, # FIXME: split when used
                slot_mapping=common_metadata.slot_mapping, # FIXME: split when used
                seq_start_loc=seq_start_loc_tensor,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                num_input_tokens=num_input_tokens,
                infer_mode=infer_mode,
                num_prefill_query_tokens=num_actual_tokens,
                num_prefill_kv_tokens=num_actual_tokens,
            ))
            # update layer_metadata
            prefill_metadata = FlashMLAPrefillMetadata(
                block_table=block_tables,
                query_start_loc=query_start_loc_tensor,
                max_query_len=max_query_len,
                chunked_context=None,
                num_prefills=batch_sizes[i],
                max_seq_len=max_query_len,
            )
            sub_layer_metadata.append(FlashMLAMetadata(
                num_reqs=batch_sizes[i],
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                query_start_loc=query_start_loc_tensor,
                slot_mapping=slot_mapping_tensor,
                num_decodes=layer_metadata.num_decodes, # useless field
                num_decode_tokens=0, # useless field
                num_prefills=batch_sizes[i],
                num_prefill_tokens=num_actual_tokens,
                head_dim=layer_metadata.head_dim,
                decode=layer_metadata.decode,
                prefill=prefill_metadata,
            ))
        # Reassemble per-part dicts with the same keys as the input.
        sub_attn_metadata_list = []
        for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
            sub_attn_metadata_dict = {}
            for key, value in attn_metadata.items():
                if key == COMMON_METADATA_STR:
                    sub_attn_metadata_dict[key] = common_meta
                else:
                    sub_attn_metadata_dict[key] = layer_meta
            sub_attn_metadata_list.append(sub_attn_metadata_dict)
        return sub_attn_metadata_list
def execute_with_updated_forward_context(
    vllm_config: VllmConfig,
    attn_metadata: AttentionMetadata,
    func: Callable,
    kwargs: Dict[str, Any],
):
    """Invoke ``func(**kwargs)`` inside a forward context built from the given
    attention metadata and vLLM config, returning whatever ``func`` returns."""
    with set_forward_context(attn_metadata, vllm_config):
        result = func(**kwargs)
    return result