Files
2026-04-24 09:58:03 +08:00

296 lines
10 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
import numpy as np
import pandas as pd
import torch
from typing import TYPE_CHECKING, Union
from dataclasses import dataclass
from enum import Enum
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.forward_context import get_forward_context
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
# Dict key under which the MLUCommonAttentionMetadata instance is stored in
# the attn_metadata dict; get_common_metadata_from_attn_metadata() reads it.
COMMON_METADATA_STR: str = "common_metadata"
class MLUInferMode(Enum):
    """Batch-level inference mode used to pick an MLU attention path.

    The mode is normally derived via ``MLUInferMode.build`` from the largest
    query length and the largest number of already-computed tokens in the
    batch.
    """

    CHUNKED = 1
    PREFILL_ONLY = 2
    DECODE_ONLY = 3

    @classmethod
    def build(
        cls,
        max_query_len,
        max_computed_tokens,
        uniform_decode_query_len: int = 1,
    ) -> Enum:
        """Classify a batch from its maximum query / computed-token counts.

        Args:
            max_query_len: Largest number of query tokens of any request.
            max_computed_tokens: Largest number of already-computed tokens
                of any request.
            uniform_decode_query_len: Largest query length still treated as
                a decode step (defaults to 1).

        Returns:
            DECODE_ONLY when ``max_query_len <= uniform_decode_query_len``;
            otherwise PREFILL_ONLY when ``max_computed_tokens == 0``;
            otherwise CHUNKED.
        """
        if max_query_len <= uniform_decode_query_len:
            return cls.DECODE_ONLY
        if max_computed_tokens == 0:
            return cls.PREFILL_ONLY
        return cls.CHUNKED

    @property
    def is_prefill_only(self):
        """True for PREFILL_ONLY."""
        return self is MLUInferMode.PREFILL_ONLY

    @property
    def is_decode_only(self):
        """True for DECODE_ONLY."""
        return self is MLUInferMode.DECODE_ONLY

    @property
    def is_chunked(self):
        """True for CHUNKED."""
        return self is MLUInferMode.CHUNKED
@dataclass
class MLUCommonAttentionMetadata(CommonAttentionMetadata):
    """
    Attention metadata attributes that can be shared by layers in different KV
    cache groups and thus having different block table.

    Extends CommonAttentionMetadata with MLU-specific fields: key/value start
    locations, the padded query-token count, prefill token counts and the
    MLUInferMode chosen for the step.
    """
    seq_start_loc: torch.Tensor | None = None
    seq_start_loc_cpu: torch.Tensor | None = None
    """(batch_size + 1,), the start location of each request in the input key/value sequence."""
    num_input_tokens: int = 0
    """Number of query tokens with padding."""
    num_prefill_query_tokens: int = 0
    """Number of query tokens in prefill phase."""
    num_prefill_kv_tokens: int = 0
    """Number of key/value tokens in prefill phase."""
    infer_mode: MLUInferMode | None = None
    """Inference mode for flash attention."""

    @property
    def is_prefill_only(self):
        """True when infer_mode is PREFILL_ONLY."""
        return self.infer_mode == MLUInferMode.PREFILL_ONLY

    @property
    def is_decode_only(self):
        """True when infer_mode is DECODE_ONLY."""
        return self.infer_mode == MLUInferMode.DECODE_ONLY

    @property
    def is_chunked(self):
        """True when infer_mode is CHUNKED."""
        return self.infer_mode == MLUInferMode.CHUNKED

    @classmethod
    def build(
        cls,
        query_start_loc, query_start_loc_cpu,
        seq_lens, seq_lens_cpu,
        num_computed_tokens_cpu,
        num_reqs, num_actual_tokens, max_query_len,
        block_table_tensor, slot_mapping,
        seq_start_loc, is_start_loc_match,
        num_input_tokens: int = 0,
        num_speculative_tokens: int = 0,
        has_prefill_reqs: bool = False
    ):
        """Build attention metadata for MLU inference.

        The inference mode is selected as:
          * PREFILL_ONLY if ``is_start_loc_match`` is truthy;
          * otherwise DECODE_ONLY if
            ``max_query_len <= 1 + num_speculative_tokens`` and
            ``has_prefill_reqs`` is false;
          * otherwise CHUNKED.

        Args:
            is_start_loc_match: Caller-computed flag that forces
                PREFILL_ONLY (presumably "query and key/value start
                locations coincide" — confirm at the call site).
            num_input_tokens: Query-token count including padding; 0 means
                "not padded" and falls back to ``num_actual_tokens``.
            num_speculative_tokens: Extra query tokens per request that
                still allow the batch to be classified as DECODE_ONLY.
            has_prefill_reqs: Whether there are pending chunked-prefill
                requests; when true the batch is never DECODE_ONLY.

        Returns:
            A new MLUCommonAttentionMetadata. ``max_seq_len`` is the maximum
            of ``seq_lens_cpu`` and ``seq_start_loc_cpu`` is a host copy of
            ``seq_start_loc``.
        """
        if is_start_loc_match:
            infer_mode = MLUInferMode.PREFILL_ONLY
        elif max_query_len <= (1 + num_speculative_tokens) and (not has_prefill_reqs):
            infer_mode = MLUInferMode.DECODE_ONLY
        else:
            infer_mode = MLUInferMode.CHUNKED
        if num_input_tokens == 0:
            # No padding information supplied: padded count == actual count.
            num_input_tokens = num_actual_tokens
        max_seq_len = int(seq_lens_cpu.max())
        return cls(query_start_loc=query_start_loc,
                   query_start_loc_cpu=query_start_loc_cpu,
                   seq_lens=seq_lens,
                   seq_lens_cpu=seq_lens_cpu,
                   num_computed_tokens_cpu=num_computed_tokens_cpu,
                   num_reqs=num_reqs,
                   num_actual_tokens=num_actual_tokens,
                   max_query_len=max_query_len,
                   max_seq_len=max_seq_len,
                   block_table_tensor=block_table_tensor,
                   slot_mapping=slot_mapping,
                   seq_start_loc=seq_start_loc,
                   # NOTE(review): non_blocking device->host copy; if
                   # seq_start_loc lives on the device, the host values are
                   # presumably only valid after the stream is synchronized.
                   seq_start_loc_cpu=seq_start_loc.to("cpu", non_blocking=True),
                   num_input_tokens=num_input_tokens,
                   infer_mode=infer_mode,
                   # NOTE(review): both prefill counters are initialised to
                   # num_actual_tokens regardless of infer_mode — confirm
                   # callers overwrite them for mixed batches.
                   num_prefill_query_tokens=num_actual_tokens,
                   num_prefill_kv_tokens=num_actual_tokens)

    def save(self, infer_phase: str):
        """Dump this step's metadata as one CSV row for offline inspection.

        Does nothing unless the ``VLLM_STEP_INPUT_CSV_PATH`` environment
        variable is set. When ``infer_phase == "RealInfer"`` the row is also
        printed. The row is appended if the file already exists with the same
        header. If the file is missing, has a different header, or cannot be
        parsed (e.g. it is empty), it is overwritten with a header row.

        Args:
            infer_phase: Label written to the ``infer_phase`` column.

        Raises:
            RuntimeError: if the directory or the CSV file cannot be written.
        """
        csv_path = os.getenv("VLLM_STEP_INPUT_CSV_PATH", None)
        if not csv_path:
            return
        header = [
            "infer_phase", "infer_mode", "num_reqs", "num_actual_tokens",
            "max_query_len", "max_seq_len", "query_start_loc", "seq_lens"
        ]
        data = [
            infer_phase, self.infer_mode, self.num_reqs,
            self.num_actual_tokens, self.max_query_len, self.max_seq_len,
            str(self.query_start_loc_cpu.tolist()),
            str(self.seq_lens_cpu.tolist())
        ]
        df_csv = pd.DataFrame(dict(zip(header, data)), index=[0])
        if infer_phase == "RealInfer":
            print(df_csv.to_string())
        try:
            if dir_path := os.path.dirname(csv_path):
                os.makedirs(dir_path, exist_ok=True)
            append = False
            if os.path.isfile(csv_path):
                try:
                    # Only the header row decides whether we can append, so
                    # avoid re-parsing the whole (growing) file every step.
                    existing_columns = pd.read_csv(csv_path, nrows=0).columns.tolist()
                    append = existing_columns == header
                except (OSError, ValueError):
                    # pandas' EmptyDataError/ParserError and UnicodeDecodeError
                    # are ValueError subclasses. An unreadable file is replaced
                    # below instead of aborting the dump.
                    import warnings
                    warnings.warn(
                        f"Existing {csv_path} failed to be read and will be overwritten")
                    append = False
            if append:
                df_csv.to_csv(csv_path, mode='a', header=False, index=False)
            else:
                df_csv.to_csv(csv_path, index=False)
        except Exception as e:
            raise RuntimeError(
                f"Invalid VLLM_STEP_INPUT_CSV_PATH: {csv_path} to dump step inputs, Error: {e}"
            ) from e
def get_common_metadata_from_attn_metadata(
        attn_metadata) -> Union[MLUCommonAttentionMetadata, None]:
    """Extract MLUCommonAttentionMetadata from an attn_metadata dict.

    Intended for MLU-V1 inference outside of set_forward_context().

    Args:
        attn_metadata: None, or a dict holding the shared metadata under the
            COMMON_METADATA_STR key.

    Returns:
        None if ``attn_metadata`` is None, otherwise the value stored under
        COMMON_METADATA_STR.

    Raises:
        AssertionError: if ``attn_metadata`` is not a dict or lacks the key
            (only while assertions are enabled).
    """
    if attn_metadata is None:
        return None
    assert isinstance(attn_metadata, dict) and COMMON_METADATA_STR in attn_metadata, (
        "MLU-V1 only support type(attn_metadata)=dict, and "
        f"{COMMON_METADATA_STR} in attn_metadata. Now, type(attn_metadata)="
        f"{type(attn_metadata)}, or {COMMON_METADATA_STR} not in attn_metadata."
    )
    return attn_metadata[COMMON_METADATA_STR]
def get_common_metadata() -> Union[MLUCommonAttentionMetadata, None]:
    """Return the MLUCommonAttentionMetadata of the current forward context.

    Intended for MLU-V1 inference inside set_forward_context(). Returns None
    when the forward context's attn_metadata is None.
    """
    forward_context = get_forward_context()
    return get_common_metadata_from_attn_metadata(forward_context.attn_metadata)
def unpad_common_attn_metadata(
    common_metadata: "MLUCommonAttentionMetadata",
    num_reqs: int,
    num_scheduled_tokens: int,
):
    """
    Unpad MLUCommonAttentionMetadata in place by num_reqs and num_scheduled_tokens.

    Sets ``num_reqs`` and ``num_input_tokens``. Truncates the per-request
    start-location arrays to ``num_reqs + 1`` entries and the per-request
    length arrays and block table to ``num_reqs`` rows. Device and host
    copies of the start locations are truncated together so they stay the
    same length. ``seq_start_loc`` / ``seq_start_loc_cpu`` default to None
    and are left untouched in that case.

    NOTE(review): num_actual_tokens, max_query_len, max_seq_len,
    num_computed_tokens_cpu and slot_mapping are not modified here — confirm
    callers do not rely on them being unpadded.
    """
    num_locs = num_reqs + 1
    common_metadata.num_reqs = num_reqs
    common_metadata.num_input_tokens = num_scheduled_tokens
    common_metadata.query_start_loc = common_metadata.query_start_loc[:num_locs]
    common_metadata.query_start_loc_cpu = common_metadata.query_start_loc_cpu[:num_locs]
    if common_metadata.seq_start_loc is not None:
        common_metadata.seq_start_loc = common_metadata.seq_start_loc[:num_locs]
    if common_metadata.seq_start_loc_cpu is not None:
        common_metadata.seq_start_loc_cpu = common_metadata.seq_start_loc_cpu[:num_locs]
    common_metadata.seq_lens = common_metadata.seq_lens[:num_reqs]
    common_metadata.seq_lens_cpu = common_metadata.seq_lens_cpu[:num_reqs]
    common_metadata.block_table_tensor = common_metadata.block_table_tensor[:num_reqs]
def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorders the batch in place into decode -> extend -> prefill regions.

    MLU variant: a request counts as "decode" only if it schedules
    <= decode_threshold tokens AND its computed-token count has reached its
    prompt length. Every other request is "extend" if it already has computed
    tokens and "prefill" otherwise. Requests are moved with
    input_batch.swap_states().

    Args:
        input_batch: Batch whose request order is modified in place.
        scheduler_output: Provides num_scheduled_tokens keyed by request id.
        decode_threshold: Maximum scheduled tokens for a request to be
            eligible as decode.

    Returns:
        True if the batch was modified, False otherwise.
    """
    # We now want to reorder the batch into decode → extend → prefill order
    # where:
    # decode: request with num_scheduled_tokens <= decode_threshold
    # extend: non-decode request with existing context
    # prefill: non-decode request with no existing context
    # NOTE for now we loosely use "decode" to mean requests where attention is
    # likely memory-bound and "prefill" to mean requests where attention is
    # likely compute-bound,
    num_reqs = len(input_batch.req_ids)
    num_scheduled_tokens = [
        scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
    ]
    num_scheduled_tokens_np = np.array(num_scheduled_tokens)
    num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: enhence decode mode condition that all prompt tokens are computed.
    '''
    # is_decode = num_scheduled_tokens_np <= decode_threshold
    # A short step of a request whose prompt is not fully computed yet (e.g.
    # a small final prefill chunk) is NOT decode; it falls into extend or
    # prefill below depending on its computed-token count.
    is_decode = (
        (num_scheduled_tokens_np <= decode_threshold)
        & (num_computed_tokens_np >= input_batch.num_prompt_tokens[:num_reqs])
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    is_extend = (~is_decode) & (num_computed_tokens_np > 0)
    is_prefill = (~is_decode) & (num_computed_tokens_np == 0)
    # Desired order: decode → extend → prefill
    # req_regions: current region of each slot; target_regions: region each
    # slot should hold once the batch is sorted by region.
    req_regions = np.zeros(is_decode.shape, dtype=np.int32)  # 0 = decode by default
    req_regions[is_extend] = 1
    req_regions[is_prefill] = 2
    num_decodes = int(is_decode.sum())
    num_extends = int(is_extend.sum())
    target_regions = np.zeros(num_reqs, dtype=np.int32)
    target_regions[num_decodes : num_decodes + num_extends] = 1
    target_regions[num_decodes + num_extends :] = 2
    # Slots whose current region differs from their target region.
    needs_swap = req_regions != target_regions
    if not needs_swap.any():
        return False
    # Extract indices that need swapping and sort by target region
    # orig_indices are the misplaced slots in ascending order (so their target
    # regions are non-decreasing); src_indices are the same slots stably
    # ordered by their CURRENT region. Pairing them maps each misplaced
    # request (src) to a misplaced slot of its own region (dst).
    orig_indices = np.where(needs_swap)[0]
    sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
    src_indices = orig_indices[sorted_order]
    src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)}
    # Apply the mapping by following cycles: each swap puts the request from
    # `src` into its final slot `dst` and brings dst's old request into `src`,
    # whose own destination is looked up next; the cycle ends when it maps
    # back to `src`. Only dict values are mutated, so iterating keys is safe.
    for src in src_dest_map:
        dst = src_dest_map[src]
        while src != dst:
            input_batch.swap_states(src, dst)
            # Mark dst as done by updating its destination to itself
            next_dst = src_dest_map.get(dst, dst)
            src_dest_map[dst] = dst
            dst = next_dst
    return True