# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
#
# Utilities for MCC (micro-batch compute/communication) planning: deciding
# whether to split a prefill workload along the batch or sequence dimension,
# and slicing attention metadata / DP runtime params / input tensors into the
# resulting sub-tasks.
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context

from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.models.dp_utils import DataParallelRuntimeParams
from vllm_mlu.v1.attention.backends.mla.flashmla import (
    FlashMLAPrefillMetadata,
    FlashMLAMetadata,
    MLACommonMetadata
)
from vllm_mlu.v1.attention.backends.utils import (
    COMMON_METADATA_STR,
    MLUCommonAttentionMetadata,
)

# Minimum token count before the sequence dimension is worth partitioning.
# NOTE(review): "PARITION" is a typo for "PARTITION", but the name is part of
# the module's public surface — renaming would break importers; kept as-is.
SEQUENCE_DIM_PARITION_THRESHOLD = 1024


def get_common_and_layer_metadata(
    attn_metadata: Optional[dict],
) -> Tuple[Optional[MLUCommonAttentionMetadata], Optional[AttentionMetadata]]:
    """
    Returns the common metadata and layer metadata from the given attention
    metadata.

    If ``attn_metadata`` is a dict it must contain the COMMON_METADATA_STR
    entry; every remaining (per-layer) entry is expected to share one object,
    so the dict holds exactly two distinct values.  If ``attn_metadata`` is
    not a dict (e.g. a bare layer metadata object during dummy run), it is
    returned as the layer metadata with no common metadata.
    """
    if attn_metadata is None:
        return None, None
    if isinstance(attn_metadata, dict):
        assert COMMON_METADATA_STR in attn_metadata, (
            f"attn_metadata must contain {COMMON_METADATA_STR} key"
        )
        # Dedup by object identity: all layer keys must alias one metadata
        # object, giving exactly two distinct values (common + shared layer).
        assert len({id(v) for v in attn_metadata.values()}) == 2, (
            f"attn_metadata should be a dict with two values, one for {COMMON_METADATA_STR} and "
            f"the other for layers."
        )
        common_metadata = attn_metadata[COMMON_METADATA_STR]
        # First value under any non-common key; all such values are aliases.
        layer_metadata = next((v for k, v in attn_metadata.items() if k != COMMON_METADATA_STR), None)
        return common_metadata, layer_metadata
    return None, attn_metadata


def should_skip_partition(
    layer_metadata: Optional[AttentionMetadata],
    common_metadata: Optional[MLUCommonAttentionMetadata],
) -> bool:
    """Helper function to simplify partition condition check.

    Returns True when either metadata is missing/unusable for prefill
    partitioning, or the batch is not prefill-only (decode is unsupported).
    """
    is_layer_metadata_invalid = (layer_metadata is None
                                 or layer_metadata.prefill is None
                                 or layer_metadata.query_start_loc is None
                                 or layer_metadata.query_start_loc.numel() == 0)
    is_common_metadata_invalid = common_metadata is None or not common_metadata.is_prefill_only
    return is_layer_metadata_invalid or is_common_metadata_invalid


def attn_mcc_plan(
    attn_metadata: Any,
    dp_params: DataParallelRuntimeParams,
    parts_to_split: int,
) -> Tuple[int, int]:
    """
    Returns the number of parts for batch size dimension and the number of
    parts for sequence length dimension.

    (1, 1) means "do not split".  At most one of the two returned values is
    greater than 1; batch-dimension splitting takes priority over
    sequence-dimension splitting.
    """
    # In the procedure of dummy run, attn_metadata is an instance of MLACommonMetadata
    if not isinstance(attn_metadata, (dict, MLACommonMetadata, type(None))):
        raise TypeError(f"attn_metadata must be dict or MLACommonMetadata, got {type(attn_metadata)}")
    if isinstance(attn_metadata, dict):
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
    else:
        common_metadata, layer_metadata = None, attn_metadata
    if dp_params is None:
        # We don't support mcc with decode yet.
        if should_skip_partition(layer_metadata, common_metadata):
            return 1, 1
        # The priority of batch size dimension to split is higher than sequence length dimension.
        # And we ensure each subtask is not empty without dp.
        # query_start_loc has num_requests + 1 entries (cumulative offsets).
        num_prefills = layer_metadata.query_start_loc.numel() - 1
        if num_prefills > 1:
            return min(parts_to_split, num_prefills), 1
        try:
            # Single request: its query length is the only diff entry.
            max_query_len = torch.diff(layer_metadata.query_start_loc).max().item()
        except RuntimeError:
            # Defensive: diff/max on a degenerate tensor — fall back to no split.
            return 1, 1
        if max_query_len < SEQUENCE_DIM_PARITION_THRESHOLD:
            return 1, 1
        return 1, min(parts_to_split, max_query_len)
    else:
        # With DP, only split when every rank is running prefill.
        if not all(is_prefill for is_prefill in dp_params.dp_is_prefill):
            return 1, 1
        max_bs = max(dp_params.batch_sizes)
        if max_bs > 1:
            # Skip splitting when the largest token load is below the threshold.
            if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD:
                return 1, 1
            # Cap parts at max_bs so every sub-batch is non-empty.
            return min(parts_to_split, max_bs), 1
        else:
            if max(dp_params.token_split_list) < SEQUENCE_DIM_PARITION_THRESHOLD:
                return 1, 1
            return 1, parts_to_split


def get_data_num_and_offset(
    total_size: int, parts_to_split: int,
) -> Tuple[List[int], List[int]]:
    """
    Get data size and offset for each.
    For example,
        total batch 11, parallel_num 4, result is [3, 3, 3, 2], offsets is [0, 3, 6, 9]
        total batch 8, parallel_num 4, result is [2, 2, 2, 2], offsets is [0, 2, 4, 6]
    """
    # Calculate the quotient and remainder of total_size divided by parts_to_split
    quotient = total_size // parts_to_split
    remainder = total_size % parts_to_split
    # The first `remainder` parts each take one extra element.
    data_num_list = [quotient + 1] * remainder + [quotient] * (parts_to_split - remainder)
    # Prefix sums give the start offset of each part; drop the final total.
    offset_list = [0] + list(itertools.accumulate(data_num_list))
    return data_num_list, offset_list[:-1]


def split_dp_params(
    dp_params: DataParallelRuntimeParams,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_data_parallel_size: int,
    attn_tensor_parallel_size: int,
    prefill_dispatch_use_RS_AG: bool,
    dp_rank_: int,
) -> List[DataParallelRuntimeParams]:
    """Split the DP runtime params into one MLUDPMetadata per sub-task.

    Exactly one of ``bs_parts_to_split`` / ``seq_parts_to_split`` may be > 1.
    Returns a list of length bs_parts_to_split * seq_parts_to_split; when no
    split is requested the original ``dp_params`` is returned unchanged (and
    ``[None] * parts`` when ``dp_params`` is None).
    """
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."
    if dp_params is None:
        return [None] * bs_parts_to_split * seq_parts_to_split
    if bs_parts_to_split * seq_parts_to_split == 1:
        return list([dp_params])
    if bs_parts_to_split == 1:
        # ---- sequence-dimension split ----
        results : List[DataParallelRuntimeParams] = []
        # Per-request token counts for each sequence part.
        dp_seq_lens = []
        for seq_len in dp_params.seq_lens:
            tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
            dp_seq_lens.append(tokens)
        query_lens_per_dp_rank = []
        # For each dp rank, the batch size is 0 or 1.
        bs_offset = 0
        for i in range(attn_data_parallel_size):
            if dp_params.batch_sizes[i] > 0:
                # Only the rank's first request is consulted (batch size <= 1 here).
                seq_len = dp_params.seq_lens[bs_offset]
                tokens, _ = get_data_num_and_offset(seq_len, seq_parts_to_split)
                query_lens_per_dp_rank.append(tokens)
                bs_offset += dp_params.batch_sizes[i]
            else:
                # Rank is idle: contribute zero tokens to every part.
                query_lens_per_dp_rank.append([0] * seq_parts_to_split)
        for i in range(seq_parts_to_split):
            # All ranks stay in prefill mode for every sub-task.
            dp_is_prefill = []
            for dp_rank in range(attn_data_parallel_size):
                dp_is_prefill.append(True)
            results.append(MLUDPMetadata.make_oot(
                data_parallel_rank=dp_rank_,
                data_parallel_size=attn_data_parallel_size,
                tensor_parallel_size=attn_tensor_parallel_size,
                dp_token_nums=[query_lens_per_dp_rank[j][i] for j in range(attn_data_parallel_size)],
                dp_is_prefill=dp_is_prefill,
                prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
                seq_lens=[seq_lens[i] for seq_lens in dp_seq_lens],
                batch_sizes=dp_params.batch_sizes,
            ))
        return results
    # ---- batch-dimension split ----
    bs_per_dp = dp_params.batch_sizes  # [bs_rank_0, bs_rank_1, ...]
    # NOTE(review): seq_lens_per_dp is assigned but never read below — confirm
    # whether it can be dropped.
    seq_lens_per_dp = dp_params.seq_lens  # [seq_len_bs_0, seq_len_bs_1,...]
    # [[bs_rank_0_part_0, bs_rank_0_part_1,...], [bs_rank_1_part_0, bs_rank_1_part_1,...], ...]
    split_bs_per_dp = []
    # [[
    #    [bs0_part_0_rank_0, bs1_part_0_rank_0, ...],
    #    [bs0_part_1_rank_0, bs1_part_1_rank_0, ...],
    #    ...
    #  ],
    #  [
    #    [bs0_part_0_rank_1, bs1_part_0_rank_1, ...],
    #    [bs0_part_1_rank_1, bs1_part_1_rank_1, ...],
    #    ...
    #  ],
    # ]
    split_query_lens_per_dp = []
    for dp_rank in range(attn_data_parallel_size):
        _bs, _offset = get_data_num_and_offset(bs_per_dp[dp_rank], bs_parts_to_split)
        split_bs_per_dp.append(_bs)
        split_query_lens_per_dp.append([])
        for i in range(bs_parts_to_split):
            # Global request index range of part i inside this rank's batch:
            # seq_lens is flattened across ranks, so skip earlier ranks first.
            start = sum(bs_per_dp[:dp_rank]) + _offset[i]
            end = start + _bs[i]
            split_query_lens_per_dp[-1].append(dp_params.seq_lens[start:end])
    results : List[DataParallelRuntimeParams] = []
    for i in range(bs_parts_to_split):
        # Token totals per rank for part i.
        dp_query_lens = [sum(split_query_lens_per_dp[dp_rank][i]) for dp_rank in range(attn_data_parallel_size)]
        # Flatten per-request seq lens across ranks, preserving rank order.
        seq_lens = []
        for dp_rank in range(attn_data_parallel_size):
            seq_lens += split_query_lens_per_dp[dp_rank][i]
        batch_sizes = []
        for dp_rank in range(attn_data_parallel_size):
            batch_sizes.append(split_bs_per_dp[dp_rank][i])
        dp_is_prefill = []
        for dp_rank in range(attn_data_parallel_size):
            dp_is_prefill.append(True)
        results.append(MLUDPMetadata.make_oot(
            data_parallel_rank=dp_rank_,
            data_parallel_size=attn_data_parallel_size,
            tensor_parallel_size=attn_tensor_parallel_size,
            dp_token_nums=dp_query_lens,
            dp_is_prefill=dp_is_prefill,
            prefill_dispatch_use_RS_AG=prefill_dispatch_use_RS_AG,
            seq_lens=seq_lens,
            batch_sizes=batch_sizes,
        ))
    return results


def split_input(
    input: torch.Tensor,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_metadata_list: List[AttentionMetadata],
) -> List[torch.Tensor]:
    """Slice the flat input tensor (token dimension 0) into one chunk per part.

    Chunk sizes come from each sub-metadata's ``num_actual_tokens``; the
    slices are views (no copy).  Returns ``[None] * parts`` when ``input`` is
    None, and the unsplit input when only one part is requested.
    """
    assert seq_parts_to_split == 1 or bs_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."
    if input is None:
        return [None] * bs_parts_to_split * seq_parts_to_split
    if bs_parts_to_split * seq_parts_to_split == 1:
        return list([input])
    token_num_list = [0] * len(attn_metadata_list)
    for i, metadata in enumerate(attn_metadata_list):
        common_metadata, layer_metadata = get_common_and_layer_metadata(metadata)
        if layer_metadata is not None:
            token_num_list[i] = layer_metadata.num_actual_tokens
        # A special case for dummy run: no layer metadata, give the whole
        # input to the first part (later parts stay empty).
        if layer_metadata is None and i == 0:
            token_num_list[i] = input.shape[0]
    results = list()
    for i in range(bs_parts_to_split * seq_parts_to_split):
        # Contiguous slices: each part starts where the previous ones end.
        start = sum(token_num_list[:i])
        end = start + token_num_list[i]
        results.append(input[start:end])
    return results


def split_positions(
    positions: torch.Tensor,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
    attn_metadata: AttentionMetadata,
) -> List[torch.Tensor]:
    """Slice the positions tensor for a sequence-dimension split.

    For a batch-dimension split (seq_parts_to_split == 1) the same positions
    tensor is shared by every part.
    """
    if seq_parts_to_split == 1:
        return [positions] * bs_parts_to_split
    common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
    total_tokens = layer_metadata.num_actual_tokens if layer_metadata is not None else 0
    tokens, offsets = get_data_num_and_offset(total_tokens, seq_parts_to_split)
    positions_list = []
    for i in range(seq_parts_to_split):
        positions_list.append(positions[offsets[i]: offsets[i] + tokens[i]])
    return positions_list


def split_attn_metadata(
    attn_metadata: dict,
    bs_parts_to_split: int,
    seq_parts_to_split: int,
) -> List[Any]:
    """ attn_metadata is a dict, which contains common and layer metadata.

    Builds one sub-dict per part with freshly constructed common/layer
    metadata.  Sequence splits require a single prefill request
    (num_prefills == 1, no decode); batch splits slice the request range.
    """
    assert bs_parts_to_split == 1 or seq_parts_to_split == 1, \
        "We don't support split batch and sequence dimensions concurrently."
    if bs_parts_to_split == 1 and seq_parts_to_split == 1:
        return list([attn_metadata])
    if attn_metadata is None:
        return [None] * bs_parts_to_split * seq_parts_to_split
    if seq_parts_to_split > 1:
        # ---- sequence-dimension split (single request) ----
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
        if common_metadata is None or not hasattr(common_metadata, 'num_actual_tokens'):
            raise ValueError("common_metadata is invalid or missing num_actual_tokens")
        num_prefill_tokens = common_metadata.num_actual_tokens
        tokens, offsets = get_data_num_and_offset(num_prefill_tokens, seq_parts_to_split)
        device = common_metadata.seq_lens.device
        sub_common_metadata, sub_layer_metadata = [], []
        for i in range(seq_parts_to_split):
            # query_start_loc tensor, which indexes positions in input.
            # Only entries [0, 1] are written; the rest of the (uninitialized)
            # tensor keeps the original's shape.
            query_start_loc_tensor = torch.empty_like(common_metadata.query_start_loc)
            query_start_loc_tensor[0] = 0
            query_start_loc_tensor[1] = tokens[i]
            # seq_lens tensor: part i attends over the kv prefix up to its end.
            seq_lens_tensor = torch.tensor(
                [offsets[i] + tokens[i]],
                dtype=common_metadata.seq_lens.dtype,
                device=device
            )
            # seq_start_loc tensor, which indicates positions in the sequence (kv cache).
            seq_start_loc_tensor = torch.empty_like(common_metadata.seq_start_loc)
            seq_start_loc_tensor[0] = offsets[i]
            seq_start_loc_tensor[1] = offsets[i] + tokens[i]
            # max_query_len scalar
            max_query_len = tokens[i]
            # num_actual_tokens scalar
            num_actual_tokens = tokens[i]
            # num_input_tokens scalar
            num_input_tokens = num_actual_tokens
            # infer_mode
            infer_mode = common_metadata.infer_mode
            # update common metadata
            sub_common_metadata.append(MLUCommonAttentionMetadata(
                query_start_loc=query_start_loc_tensor,
                query_start_loc_cpu=common_metadata.query_start_loc_cpu,  # FIXME: split when used
                seq_lens=seq_lens_tensor,
                seq_lens_cpu=common_metadata.seq_lens_cpu,  # FIXME: split when used
                num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu,  # FIXME: split when used
                num_reqs=common_metadata.num_reqs,  # FIXME: split when used
                num_actual_tokens=num_actual_tokens,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                block_table_tensor=common_metadata.block_table_tensor,  # FIXME: split when used
                slot_mapping=common_metadata.slot_mapping,  # FIXME: split when used
                seq_start_loc=seq_start_loc_tensor,
                num_input_tokens=num_input_tokens,
                infer_mode=infer_mode,
                num_prefill_query_tokens=tokens[i],
                num_prefill_kv_tokens=offsets[i] + tokens[i],
            ))
            # slot_mapping tensor: this part's token range only.
            slot_mapping = layer_metadata.slot_mapping[offsets[i]:offsets[i] + tokens[i]]
            # update layer metadata
            # Sequence split is only valid for exactly one prefill request
            # and no decode work.
            REQUIRED_NUM_DECODES = 0
            REQUIRED_NUM_DECODE_TOKENS = 0
            REQUIRED_NUM_PREFILLS = 1
            if not hasattr(layer_metadata, 'num_prefills') or \
                layer_metadata.num_prefills is None:
                raise ValueError("layer_metadata.num_prefills is required")
            assert layer_metadata.num_decodes == REQUIRED_NUM_DECODES and \
                layer_metadata.num_decode_tokens == REQUIRED_NUM_DECODE_TOKENS and \
                layer_metadata.num_prefills == REQUIRED_NUM_PREFILLS, (
                f"num_decodes, num_decode_tokens, num_prefills must be {REQUIRED_NUM_DECODES}, {REQUIRED_NUM_DECODE_TOKENS}, "
                f"{REQUIRED_NUM_PREFILLS}, but got {layer_metadata.num_decodes}, {layer_metadata.num_decode_tokens}, "
                f"{layer_metadata.num_prefills}."
            )
            assert layer_metadata.prefill.chunked_context is None, (
                f"chunked_context is only available for prefill with chunked context, "
                f"and it is not supported when enabling mcc."
            )
            prefill_metadata = FlashMLAPrefillMetadata(
                block_table=layer_metadata.prefill.block_table,
                query_start_loc=query_start_loc_tensor,
                max_query_len=max_query_len,
                chunked_context=None,
                num_prefills=layer_metadata.prefill.num_prefills,
                max_seq_len=layer_metadata.prefill.max_seq_len,
            )
            # Note: for sequence dimension partition, we provide the cu_seqlens_kv field
            # to indicate the key/value size for the flash attention operator.
            prefill_metadata.cu_seqlens_kv = torch.empty_like(prefill_metadata.query_start_loc)
            prefill_metadata.cu_seqlens_kv[0] = 0
            prefill_metadata.cu_seqlens_kv[1] = offsets[i] + tokens[i]
            sub_layer_metadata.append(FlashMLAMetadata(
                num_reqs=layer_metadata.num_reqs,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                query_start_loc=query_start_loc_tensor,
                slot_mapping=slot_mapping,
                num_decodes=layer_metadata.num_decodes,
                num_decode_tokens=layer_metadata.num_decode_tokens,
                num_prefills=layer_metadata.num_prefills,
                num_prefill_tokens=tokens[i],
                head_dim=layer_metadata.head_dim,
                decode=layer_metadata.decode,
                prefill=prefill_metadata,
            ))
        # Rebuild the original dict shape for each part: common metadata
        # under COMMON_METADATA_STR, the shared layer metadata everywhere else.
        sub_attn_metadata_list = []
        for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
            sub_attn_metadata_dict = {}
            for key, value in attn_metadata.items():
                if key == COMMON_METADATA_STR:
                    sub_attn_metadata_dict[key] = common_meta
                else:
                    sub_attn_metadata_dict[key] = layer_meta
            sub_attn_metadata_list.append(sub_attn_metadata_dict)
        return sub_attn_metadata_list
    elif bs_parts_to_split > 1:
        # ---- batch-dimension split ----
        common_metadata, layer_metadata = get_common_and_layer_metadata(attn_metadata)
        if not hasattr(layer_metadata, 'num_prefills') or layer_metadata.num_prefills is None:
            raise ValueError("layer_metadata.num_prefills is required")
        total_batch = layer_metadata.num_prefills
        batch_sizes, offsets = get_data_num_and_offset(total_batch, bs_parts_to_split)
        sub_common_metadata, sub_layer_metadata = [], []
        for i in range(bs_parts_to_split):
            # query_start_loc tensor: this part's request range, rebased to 0.
            start, end = offsets[i], offsets[i] + batch_sizes[i]
            query_start_loc_tensor = common_metadata.query_start_loc[start:end+1].clone()
            if i > 0:
                # First part already starts at 0, later parts must be shifted.
                query_start_loc_tensor -= common_metadata.query_start_loc[start]
            # block_table: intentionally empty (0 blocks per request).
            block_tables = torch.empty(
                (batch_sizes[i], 0),
                dtype=layer_metadata.prefill.block_table.dtype,
                device=layer_metadata.prefill.block_table.device,
            )
            # seq_lens tensor
            seq_lens_tensor = common_metadata.seq_lens[start:end].clone()
            # seq_start_loc tensor
            # NOTE(review): aliases query_start_loc_tensor (no copy) — assumes
            # seq == query offsets for pure prefill; confirm.
            seq_start_loc_tensor = query_start_loc_tensor
            # max_query_len scalar (0 guards an empty part)
            max_query_len = seq_lens_tensor.max().item() if seq_lens_tensor.numel() > 0 else 0
            # num_actual_tokens scalar: last cumulative offset = token total.
            num_actual_tokens = seq_start_loc_tensor[-1].item()
            # num_input_tokens scalar
            num_input_tokens = num_actual_tokens
            # infer_mode
            infer_mode = common_metadata.infer_mode
            # slot_mapping tensor: start after the tokens consumed by the
            # parts built so far (accumulated from sub_common_metadata).
            slot_mapping_start = 0
            for data in sub_common_metadata:
                slot_mapping_start += data.num_actual_tokens
            slot_mapping_tensor = layer_metadata.slot_mapping[
                slot_mapping_start:slot_mapping_start + num_actual_tokens
            ]
            # update common metadata
            sub_common_metadata.append(MLUCommonAttentionMetadata(
                query_start_loc=query_start_loc_tensor,
                query_start_loc_cpu=common_metadata.query_start_loc_cpu,  # FIXME: split when used
                seq_lens=seq_lens_tensor,
                seq_lens_cpu=common_metadata.seq_lens_cpu,  # FIXME: split when used
                num_computed_tokens_cpu=common_metadata.num_computed_tokens_cpu,  # FIXME: split when used
                num_reqs=common_metadata.num_reqs,  # FIXME: split when used
                block_table_tensor=common_metadata.block_table_tensor,  # FIXME: split when used
                slot_mapping=common_metadata.slot_mapping,  # FIXME: split when used
                seq_start_loc=seq_start_loc_tensor,
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                num_input_tokens=num_input_tokens,
                infer_mode=infer_mode,
                num_prefill_query_tokens=num_actual_tokens,
                num_prefill_kv_tokens=num_actual_tokens,
            ))
            # update layer_metadata
            prefill_metadata = FlashMLAPrefillMetadata(
                block_table=block_tables,
                query_start_loc=query_start_loc_tensor,
                max_query_len=max_query_len,
                chunked_context=None,
                num_prefills=batch_sizes[i],
                max_seq_len=max_query_len,
            )
            sub_layer_metadata.append(FlashMLAMetadata(
                num_reqs=batch_sizes[i],
                max_query_len=max_query_len,
                max_seq_len=max_query_len,
                num_actual_tokens=num_actual_tokens,
                query_start_loc=query_start_loc_tensor,
                slot_mapping=slot_mapping_tensor,
                num_decodes=layer_metadata.num_decodes,  # useless field
                num_decode_tokens=0,  # useless field
                num_prefills=batch_sizes[i],
                num_prefill_tokens=num_actual_tokens,
                head_dim=layer_metadata.head_dim,
                decode=layer_metadata.decode,
                prefill=prefill_metadata,
            ))
        # Rebuild the original dict shape for each part (same as the
        # sequence-split branch above).
        sub_attn_metadata_list = []
        for common_meta, layer_meta in zip(sub_common_metadata, sub_layer_metadata):
            sub_attn_metadata_dict = {}
            for key, value in attn_metadata.items():
                if key == COMMON_METADATA_STR:
                    sub_attn_metadata_dict[key] = common_meta
                else:
                    sub_attn_metadata_dict[key] = layer_meta
            sub_attn_metadata_list.append(sub_attn_metadata_dict)
        return sub_attn_metadata_list


def execute_with_updated_forward_context(
    vllm_config: VllmConfig,
    attn_metadata: AttentionMetadata,
    func: Callable,
    kwargs: Dict[str, Any],
):
    """Run ``func(**kwargs)`` inside a forward context built from the given
    (possibly per-part) attention metadata, returning its result."""
    with set_forward_context(attn_metadata, vllm_config):
        return func(**kwargs)