[Model] Support DeepSeek-V4
This commit is contained in:
3
vllm_mlu/v1/attention/backends/__init__.py
Normal file
3
vllm_mlu/v1/attention/backends/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
1050
vllm_mlu/v1/attention/backends/flash_attn.py
Normal file
1050
vllm_mlu/v1/attention/backends/flash_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
404
vllm_mlu/v1/attention/backends/gdn_attn.py
Normal file
404
vllm_mlu/v1/attention/backends/gdn_attn.py
Normal file
@@ -0,0 +1,404 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from collections import OrderedDict, deque
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.v1.attention.backends.gdn_attn import (GDNAttentionMetadataBuilder,
|
||||
GDNAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,)
|
||||
|
||||
|
||||
class DeviceAwareLocalIdMapper:
|
||||
def __init__(self, batch_size: int):
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size must be positive")
|
||||
self.batch_size = batch_size
|
||||
self.global_to_local: OrderedDict[int, int] = OrderedDict()
|
||||
self.local_to_global = {}
|
||||
self.available_local_ids = deque(range(batch_size))
|
||||
|
||||
def batch_get_local_ids(self, global_id_tensor: torch.Tensor) -> torch.Tensor:
|
||||
original_device = global_id_tensor.device
|
||||
original_shape = global_id_tensor.shape
|
||||
|
||||
flat_global_cpu = global_id_tensor.cpu().numpy().ravel()
|
||||
num_elements = flat_global_cpu.size
|
||||
local_ids_cpu = torch.empty(num_elements, dtype=global_id_tensor.dtype)
|
||||
|
||||
g2l = self.global_to_local
|
||||
unique_miss_set = set()
|
||||
|
||||
# Pass 1: handle hits and collect unique misses
|
||||
for i, gid in enumerate(flat_global_cpu):
|
||||
if gid in g2l:
|
||||
local_id = g2l[gid]
|
||||
local_ids_cpu[i] = local_id
|
||||
g2l.move_to_end(gid)
|
||||
else:
|
||||
local_ids_cpu[i] = -1
|
||||
unique_miss_set.add(gid)
|
||||
|
||||
# Pass 2: assign local IDs to unique new global IDs
|
||||
new_mappings = {}
|
||||
available = self.available_local_ids
|
||||
local_to_global = self.local_to_global
|
||||
|
||||
for gid in unique_miss_set:
|
||||
if len(g2l) >= self.batch_size:
|
||||
old_gid, old_local = g2l.popitem(last=False)
|
||||
available.append(old_local)
|
||||
local_to_global.pop(old_local, None)
|
||||
new_local = available.popleft()
|
||||
g2l[gid] = new_local
|
||||
local_to_global[new_local] = gid
|
||||
new_mappings[gid] = new_local
|
||||
|
||||
# Pass 3: fill in all miss positions
|
||||
for i, gid in enumerate(flat_global_cpu):
|
||||
if local_ids_cpu[i].item() == -1:
|
||||
local_ids_cpu[i] = new_mappings[gid]
|
||||
|
||||
return local_ids_cpu.to(original_device).view(original_shape)
|
||||
|
||||
def reset(self):
|
||||
self.global_to_local.clear()
|
||||
self.local_to_global.clear()
|
||||
self.available_local_ids = deque(range(self.batch_size))
|
||||
|
||||
def vllm__v1__attention__bachends__GDNAttentionMetadataBuilder____init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
if self.speculative_config:
|
||||
self.num_spec = self.speculative_config.num_speculative_tokens
|
||||
else:
|
||||
self.num_spec = 0
|
||||
self.use_spec_decode = self.num_spec > 0
|
||||
self._init_reorder_batch_threshold(1, self.use_spec_decode)
|
||||
|
||||
self.use_full_cuda_graph = (
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
)
|
||||
self.decode_cudagraph_max_bs = min(
|
||||
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
|
||||
self.compilation_config.max_cudagraph_capture_size,
|
||||
)
|
||||
|
||||
self.spec_state_indices_tensor = torch.empty(
|
||||
(self.decode_cudagraph_max_bs, self.num_spec + 1),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.non_spec_state_indices_tensor = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.spec_sequence_masks = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
self.spec_token_indx = torch.empty(
|
||||
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.non_spec_token_indx = torch.empty(
|
||||
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.spec_query_start_loc = torch.empty(
|
||||
(self.decode_cudagraph_max_bs + 1,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.non_spec_query_start_loc = torch.empty(
|
||||
(self.decode_cudagraph_max_bs + 1,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.num_accepted_tokens = torch.empty(
|
||||
(self.decode_cudagraph_max_bs,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support qwen3-next
|
||||
'''
|
||||
self.mapper = DeviceAwareLocalIdMapper(self.vllm_config.mlu_config.mamba_support_max_batch_size)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__v1__attention__bachends__GDNAttentionMetadataBuilder__build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
num_decode_draft_tokens_cpu: torch.Tensor | None = None,
|
||||
fast_build: bool = False,
|
||||
) -> GDNAttentionMetadata:
|
||||
m = common_attn_metadata
|
||||
|
||||
query_start_loc = m.query_start_loc
|
||||
context_lens = m.num_computed_tokens_cpu
|
||||
context_lens_tensor = context_lens.to(query_start_loc.device)
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if (
|
||||
not self.use_spec_decode
|
||||
or num_decode_draft_tokens_cpu is None
|
||||
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
|
||||
.sum()
|
||||
.item()
|
||||
== 0
|
||||
):
|
||||
spec_sequence_masks = None
|
||||
num_spec_decodes = 0
|
||||
else:
|
||||
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
|
||||
num_spec_decodes = spec_sequence_masks.sum().item()
|
||||
if num_spec_decodes == 0:
|
||||
spec_sequence_masks = None
|
||||
else:
|
||||
spec_sequence_masks = spec_sequence_masks.to(
|
||||
query_start_loc.device, non_blocking=True
|
||||
)
|
||||
|
||||
if spec_sequence_masks is None:
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(m, decode_threshold=1)
|
||||
)
|
||||
num_spec_decode_tokens = 0
|
||||
spec_token_indx = None
|
||||
non_spec_token_indx = None
|
||||
spec_state_indices_tensor = None
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
|
||||
spec_query_start_loc = None
|
||||
non_spec_query_start_loc = query_start_loc
|
||||
num_accepted_tokens = None
|
||||
else:
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
|
||||
non_spec_query_lens = query_lens[~spec_sequence_masks]
|
||||
num_decodes = (non_spec_query_lens == 1).sum().item()
|
||||
num_prefills = non_spec_query_lens.size(0) - num_decodes
|
||||
num_decode_tokens = num_decodes
|
||||
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
|
||||
num_spec_decode_tokens = (
|
||||
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
|
||||
)
|
||||
|
||||
if num_prefills == 0 and num_decodes == 0:
|
||||
spec_token_size = min(
|
||||
num_spec_decodes * (self.num_spec + 1),
|
||||
query_start_loc[-1].item(),
|
||||
)
|
||||
spec_token_indx = torch.arange(
|
||||
spec_token_size,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
non_spec_token_indx = torch.empty(
|
||||
0, dtype=torch.int32, device=query_start_loc.device
|
||||
)
|
||||
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
|
||||
non_spec_state_indices_tensor = None
|
||||
spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc = None
|
||||
else:
|
||||
spec_token_masks = torch.repeat_interleave(
|
||||
spec_sequence_masks, query_lens
|
||||
)
|
||||
index = torch.argsort(spec_token_masks)
|
||||
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
|
||||
non_spec_token_indx = index[:num_non_spec_tokens]
|
||||
spec_token_indx = index[num_non_spec_tokens:]
|
||||
|
||||
spec_state_indices_tensor = m.block_table_tensor[
|
||||
spec_sequence_masks, : self.num_spec + 1
|
||||
]
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[
|
||||
~spec_sequence_masks, 0
|
||||
]
|
||||
|
||||
spec_query_start_loc = torch.zeros(
|
||||
num_spec_decodes + 1,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
torch.cumsum(
|
||||
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
|
||||
)
|
||||
non_spec_query_start_loc = torch.zeros(
|
||||
query_lens.size(0) - num_spec_decodes + 1,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
torch.cumsum(
|
||||
query_lens[~spec_sequence_masks],
|
||||
dim=0,
|
||||
out=non_spec_query_start_loc[1:],
|
||||
)
|
||||
|
||||
assert num_accepted_tokens is not None
|
||||
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
|
||||
|
||||
if num_prefills > 0:
|
||||
has_initial_state = context_lens_tensor > 0
|
||||
if spec_sequence_masks is not None:
|
||||
has_initial_state = has_initial_state[~spec_sequence_masks]
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(non_spec_query_start_loc)
|
||||
)
|
||||
else:
|
||||
has_initial_state = None
|
||||
num_actual_tokens = (
|
||||
num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
|
||||
)
|
||||
|
||||
# prepare tensors for cudagraph
|
||||
#
|
||||
# With speculative decoding, the xgrammar backend may rollback tokens
|
||||
# and causing some sequences has less draft tokens than self.num_spec.
|
||||
#
|
||||
# In above cases, the max possible batch size for n tokens, can be
|
||||
# min(n, cudagraph_max_bs).
|
||||
if (
|
||||
self.use_full_cuda_graph
|
||||
and num_prefills == 0
|
||||
and num_decodes == 0
|
||||
and num_spec_decodes <= self.decode_cudagraph_max_bs
|
||||
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
|
||||
):
|
||||
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
|
||||
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
|
||||
|
||||
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
|
||||
spec_state_indices_tensor, non_blocking=True
|
||||
)
|
||||
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
|
||||
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
|
||||
|
||||
self.spec_sequence_masks[:num_spec_decodes].copy_(
|
||||
spec_sequence_masks, non_blocking=True
|
||||
)
|
||||
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
|
||||
spec_sequence_masks[num_spec_decodes:].fill_(False)
|
||||
|
||||
assert non_spec_token_indx is not None and spec_token_indx is not None
|
||||
self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
|
||||
non_spec_token_indx, non_blocking=True
|
||||
)
|
||||
non_spec_token_indx = self.non_spec_token_indx[
|
||||
: non_spec_token_indx.size(0)
|
||||
]
|
||||
|
||||
self.spec_token_indx[: spec_token_indx.size(0)].copy_(
|
||||
spec_token_indx, non_blocking=True
|
||||
)
|
||||
spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]
|
||||
|
||||
self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
|
||||
spec_query_start_loc, non_blocking=True
|
||||
)
|
||||
spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index]
|
||||
spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
|
||||
spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)
|
||||
|
||||
self.num_accepted_tokens[:num_spec_decodes].copy_(
|
||||
num_accepted_tokens, non_blocking=True
|
||||
)
|
||||
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
|
||||
num_accepted_tokens[num_spec_decodes:].fill_(1)
|
||||
|
||||
if (
|
||||
self.use_full_cuda_graph
|
||||
and num_prefills == 0
|
||||
and num_spec_decodes == 0
|
||||
and num_decodes <= self.decode_cudagraph_max_bs
|
||||
):
|
||||
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
|
||||
batch_size = num_actual_tokens
|
||||
|
||||
self.non_spec_state_indices_tensor[:num_decodes].copy_(
|
||||
non_spec_state_indices_tensor, non_blocking=True
|
||||
)
|
||||
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
|
||||
:batch_size
|
||||
]
|
||||
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
|
||||
|
||||
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
|
||||
non_spec_query_start_loc, non_blocking=True
|
||||
)
|
||||
non_spec_num_query_tokens = non_spec_query_start_loc[-1] # type: ignore[index]
|
||||
non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
|
||||
non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support qwen3-next
|
||||
'''
|
||||
non_spec_state_indices_tensor = self.mapper.batch_get_local_ids(non_spec_state_indices_tensor)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
attn_metadata = GDNAttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_spec_decodes=num_spec_decodes,
|
||||
num_spec_decode_tokens=num_spec_decode_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
has_initial_state=has_initial_state,
|
||||
spec_query_start_loc=spec_query_start_loc,
|
||||
non_spec_query_start_loc=non_spec_query_start_loc,
|
||||
spec_state_indices_tensor=spec_state_indices_tensor,
|
||||
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
|
||||
spec_sequence_masks=spec_sequence_masks,
|
||||
spec_token_indx=spec_token_indx,
|
||||
non_spec_token_indx=non_spec_token_indx,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
MluHijackObject.apply_hijack(GDNAttentionMetadataBuilder,
|
||||
GDNAttentionMetadataBuilder.__init__,
|
||||
vllm__v1__attention__bachends__GDNAttentionMetadataBuilder____init__)
|
||||
|
||||
MluHijackObject.apply_hijack(GDNAttentionMetadataBuilder,
|
||||
GDNAttentionMetadataBuilder.build,
|
||||
vllm__v1__attention__bachends__GDNAttentionMetadataBuilder__build)
|
||||
934
vllm_mlu/v1/attention/backends/mla/flashmla.py
Normal file
934
vllm_mlu/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,934 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv, round_down
|
||||
from vllm.attention.backends.utils import MLADims
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend, MLACommonPrefillMetadata,
|
||||
MLACommonDecodeMetadata, MLACommonMetadata,
|
||||
MLACommonMetadataBuilder, M, QueryLenSupport,
|
||||
use_cudnn_prefill, use_flashinfer_prefill,
|
||||
use_trtllm_ragged_deepseek_prefill,
|
||||
FlashInferPrefillMetadata,
|
||||
CudnnPrefillMetadata,
|
||||
MLACommonImpl,
|
||||
CUDNN_WORKSPACE_SIZE
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, split_decodes_and_prefills,
|
||||
infer_global_hyperparameters, get_per_layer_parameters,
|
||||
)
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
MLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
import vllm_mlu._mlu_utils as mlu_envs
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.v1.attention.backends.flash_attn import MLUFlashAttentionImpl
|
||||
from vllm_mlu.v1.attention.backends.utils import (
|
||||
MLUCommonAttentionMetadata, get_common_metadata,
|
||||
MLUInferMode)
|
||||
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
|
||||
from vllm.platforms import current_platform
|
||||
from vllm import envs
|
||||
try:
|
||||
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
|
||||
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache # noqa: F401
|
||||
|
||||
flashinfer_available = True
|
||||
except ImportError:
|
||||
BatchPrefillWithRaggedKVCacheWrapper = object
|
||||
|
||||
flashinfer_available = False
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
class MLACommonBackend_MluHijack(MLACommonBackend):
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576, 512]
|
||||
|
||||
|
||||
def get_mla_dims(model_config: ModelConfig) -> MLADims:
|
||||
hf_text_config = model_config.hf_text_config
|
||||
|
||||
if model_config.hf_text_config.model_type == "deepseek_v4":
|
||||
return MLADims(
|
||||
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
|
||||
kv_lora_rank=hf_text_config.head_dim,
|
||||
qk_nope_head_dim=hf_text_config.head_dim - hf_text_config.rope_head_dim,
|
||||
qk_rope_head_dim=hf_text_config.rope_head_dim,
|
||||
v_head_dim=hf_text_config.head_dim,
|
||||
)
|
||||
|
||||
return MLADims(
|
||||
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
|
||||
kv_lora_rank=hf_text_config.kv_lora_rank,
|
||||
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
|
||||
v_head_dim=hf_text_config.v_head_dim,
|
||||
)
|
||||
|
||||
|
||||
class MLACommonMetadataBuilder_MluHijack(MLACommonMetadataBuilder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: type[M] | None = None,
|
||||
supports_dcp_with_varlen: bool = False,
|
||||
):
|
||||
self.metadata_cls = (
|
||||
metadata_cls if metadata_cls is not None else MLACommonMetadata
|
||||
)
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
|
||||
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.aot_schedule = current_platform.is_cuda()
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
|
||||
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
|
||||
|
||||
# Don't try to access the runner on AMD
|
||||
if self.aot_schedule:
|
||||
self.page_size = self.kv_cache_spec.block_size
|
||||
|
||||
self.chunked_prefill_workspace_size = (
|
||||
self.determine_chunked_prefill_workspace_size(vllm_config)
|
||||
)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
# Note(hc): The local kvcache is incomplete when DCP is triggered,
|
||||
# an additional kvcache allgather across the DCP group is therefore
|
||||
# required, so the workspace has to be enlarged by 1/DCP relative
|
||||
# to the original TP allocation.
|
||||
assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(
|
||||
self.chunked_prefill_workspace_size
|
||||
+ self.chunked_prefill_workspace_size // self.dcp_world_size,
|
||||
self.model_config.get_head_size(),
|
||||
),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(
|
||||
self.chunked_prefill_workspace_size,
|
||||
self.model_config.get_head_size(),
|
||||
),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self._use_cudnn_prefill = use_cudnn_prefill()
|
||||
self._use_fi_prefill = use_flashinfer_prefill()
|
||||
self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
|
||||
self.prefill_metadata_cls = (
|
||||
FlashInferPrefillMetadata
|
||||
if self._use_fi_prefill
|
||||
else CudnnPrefillMetadata
|
||||
if self._use_cudnn_prefill
|
||||
else MLACommonPrefillMetadata
|
||||
)
|
||||
|
||||
if self._use_fi_prefill:
|
||||
self._workspace_buffer = torch.empty(
|
||||
envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
|
||||
self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []
|
||||
|
||||
self._global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
|
||||
)
|
||||
|
||||
if self._use_trtllm_ragged_prefill:
|
||||
self._workspace_buffer = torch.empty(
|
||||
envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if self._use_cudnn_prefill:
|
||||
self.cudnn_workspace = torch.empty(
|
||||
CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
|
||||
dtype=torch.int8,
|
||||
device=device,
|
||||
)
|
||||
|
||||
supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
|
||||
self._init_reorder_batch_threshold(
|
||||
self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
|
||||
)
|
||||
|
||||
# Validate consistency between query_len_support and reorder_batch_threshold
|
||||
if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
|
||||
assert self.reorder_batch_threshold == 1, (
|
||||
f"reorder_batch_threshold must be 1 when query_len_support is "
|
||||
f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
|
||||
)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(MLACommonBackend,
|
||||
MLACommonBackend.get_supported_head_sizes,
|
||||
MLACommonBackend_MluHijack.get_supported_head_sizes)
|
||||
MluHijackObject.apply_hijack(MLACommonMetadataBuilder,
|
||||
MLACommonMetadataBuilder.__init__,
|
||||
MLACommonMetadataBuilder_MluHijack.__init__)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["FlashMLAMetadata"]:
|
||||
return FlashMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
|
||||
return FlashMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return (1, num_blocks, num_kv_heads, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_scale_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
) -> tuple[int, ...]:
|
||||
return (1, num_blocks, num_kv_heads, block_size)
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576, 512]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAPrefillMetadata(MLACommonPrefillMetadata):
|
||||
num_prefills: int = -1 # for gather_cache
|
||||
max_seq_len: int = -1 # for attn forward
|
||||
|
||||
@property
|
||||
def block_tables(self):
|
||||
return self.block_table
|
||||
|
||||
@property
|
||||
def context_chunk_cu_seq_lens(self):
|
||||
if self.chunked_context is None:
|
||||
return None
|
||||
return self.chunked_context.cu_seq_lens
|
||||
|
||||
@property
|
||||
def context_chunk_starts(self):
|
||||
if self.chunked_context is None:
|
||||
return None
|
||||
return self.chunked_context.starts
|
||||
|
||||
@property
|
||||
def context_chunk_seq_tot(self):
|
||||
if self.chunked_context is None:
|
||||
return None
|
||||
return self.chunked_context.seq_tot
|
||||
|
||||
@property
|
||||
def context_chunk_max_seq_lens(self):
|
||||
if self.chunked_context is None:
|
||||
return None
|
||||
return self.chunked_context.max_seq_lens
|
||||
|
||||
@property
|
||||
def context_chunk_workspace(self):
|
||||
if self.chunked_context is None:
|
||||
return None
|
||||
return self.chunked_context.workspace
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
tile_scheduler_metadata: torch.Tensor
|
||||
num_splits: torch.Tensor
|
||||
|
||||
# add for mlu rope and attn forward
|
||||
query_start_loc: torch.Tensor # for rope
|
||||
max_query_len: int # for rope
|
||||
max_seq_len:int = -1 # for attn forward
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata):
|
||||
num_prefill_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
|
||||
# ^ TODO(matt): tune this
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
|
||||
)
|
||||
|
||||
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: 1. set decoder_query_len for mtp
|
||||
@brief: 2. init chunk workspace for prefix_caching only
|
||||
@brief: 3. set prefill_metadata_cls
|
||||
@brief: 4. add deepseek v3.2 infos
|
||||
'''
|
||||
cache_config = vllm_config.cache_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
speculative_config = vllm_config.speculative_config
|
||||
self.num_speculative_tokens = (speculative_config.num_speculative_tokens
|
||||
if speculative_config is not None else 0)
|
||||
self.decoder_query_len = 1 + self.num_speculative_tokens
|
||||
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.is_deepseek_v32 = self.model_config.hf_text_config.model_type == "deepseek_v32"
|
||||
|
||||
self.enable_caching = cache_config.enable_prefix_caching
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
if (not self.is_deepseek_v32 and not self.chunked_prefill_enabled and
|
||||
(mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED and self.enable_caching)):
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
# Max sure there is enough for 8 full length request or at least
|
||||
# 4 pages of cache per request
|
||||
max(
|
||||
8 * self.model_config.max_model_len, 4 *
|
||||
scheduler_config.max_num_seqs * cache_config.block_size),
|
||||
# For long-context models try not to over-allocate limiting
|
||||
# kv-cache space, limiting it to 64k tokens,
|
||||
# which would result in the workspace being:
|
||||
# 2*(576)*(64*1024) = 144mb
|
||||
# (assuming 576 MLA head dim, and fp16)
|
||||
# which would result in up-projected context being
|
||||
# 2*(192*128)*(64*1024) = 3gb
|
||||
# (assuming 192 QK head dim, 128 heads, and fp16)
|
||||
128 * 1024)
|
||||
assert self.chunked_prefill_workspace_size >= \
|
||||
scheduler_config.max_num_seqs * cache_config.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.prefill_metadata_cls = FlashMLAPrefillMetadata
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are and
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
# mlu v1 mtp forces decoder_query_len = 1 for k > 1, so we should set again
|
||||
self.decoder_query_len = 1 + self.num_speculative_tokens
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the TritonMLA._forward_decode only supports
|
||||
# num_tokens = 1
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: record prefill and decode requests and token nums to call
|
||||
chunked fa and single-query attn respectively in forward.
|
||||
@Notes: decodes need all prompt tokens are computed.
|
||||
'''
|
||||
req_index = input_batch.req_id_to_index.get(req_id)
|
||||
all_prompt_tokens_has_computed = (
|
||||
input_batch.num_computed_tokens_cpu[req_index] >=
|
||||
input_batch.num_prompt_tokens[req_index])
|
||||
if num_tokens <= self.decoder_query_len and all_prompt_tokens_has_computed:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
return modified_batch
|
||||
|
||||
def _build_decode(
|
||||
self,
|
||||
block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
max_query_len: int,
|
||||
max_seq_len: int,
|
||||
) -> FlashMLADecodeMetadata:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: set tile_scheduler_metadata and num_splits to None.
|
||||
@brief: set dcp_tot_seq_lens_device.
|
||||
'''
|
||||
return FlashMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
tile_scheduler_metadata=None,
|
||||
num_splits=None,
|
||||
dcp_tot_seq_lens=None,
|
||||
# for mlu
|
||||
max_seq_len=max_seq_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_query_len=max_query_len
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: MLUCommonAttentionMetadata) -> M:
|
||||
"""
|
||||
This method builds the metadata for full cudagraph capture.
|
||||
Currently, only decode is supported for full cudagraphs with MLA.
|
||||
"""
|
||||
m = common_attn_metadata
|
||||
if m.infer_mode == MLUInferMode.DECODE_ONLY:
|
||||
assert m.num_reqs * m.max_query_len == m.num_actual_tokens, \
|
||||
"MLA only supports decode-only full CUDAGraph capture. " \
|
||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||
|
||||
return self.build(0, m)
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: MLUCommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
input_batch: "InputBatch" = None) -> M:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||
# it blocks on all previous kernels.
|
||||
device = self.device
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
|
||||
query_seq_lens_cpu)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support normal and mtp input split
|
||||
'''
|
||||
if input_batch is None:
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata,
|
||||
self.decoder_query_len)
|
||||
else:
|
||||
num_decodes, num_prefills = input_batch.split_decodes_and_prefills()
|
||||
num_decode_tokens = common_attn_metadata.query_start_loc_cpu[num_decodes].item()
|
||||
num_prefill_tokens = num_tokens - num_decode_tokens
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
reqs_start = num_decodes # prefill_start
|
||||
|
||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: avoid buffer missing when prefill_only + mlugraph
|
||||
'''
|
||||
if num_decodes > 0:
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
else:
|
||||
prefill_query_start_loc= query_start_loc
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
chunked_context_metadata = None
|
||||
if ((self.chunked_prefill_enabled or
|
||||
(mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED and
|
||||
self.enable_caching and
|
||||
common_attn_metadata.is_chunked)
|
||||
) and num_prefills > 0 and max_context_len_cpu > 0):
|
||||
# NOTE: it is recommend you read the `Chunked Prefill` section
|
||||
# in the comment at the top of the file before trying to
|
||||
# understand the following code
|
||||
|
||||
# currently we allocate an equal amount of workspace for each
|
||||
# prefill in the batch, we could probably use a more advanced
|
||||
# algorithm here and allocate more workspace to prefills with
|
||||
# longer context lengths
|
||||
if self.is_deepseek_v32:
|
||||
max_context_chunk = self.max_model_len
|
||||
else:
|
||||
max_context_chunk = (self.chunked_prefill_workspace_size //
|
||||
num_prefills_with_context_cpu)
|
||||
|
||||
if self.aot_schedule:
|
||||
# align max_context_chunk to page_size by rounding down,
|
||||
# currently the `gather_cache` kernel cannot handle
|
||||
# `context_chunk_starts` that are not aligned to page_size
|
||||
max_context_chunk = round_down(max_context_chunk,
|
||||
self.page_size)
|
||||
|
||||
assert max_context_chunk > 0
|
||||
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
||||
|
||||
# if `max_context_chunk = 256`, `num_chunks = 3`, and
|
||||
# `num_prefills_with_context = 4`, create a tensor that looks
|
||||
# like
|
||||
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
|
||||
# Note(simon): this is done in CPU because of downstream's
|
||||
# of `to_list`.
|
||||
chunk_starts = \
|
||||
torch.arange(num_chunks, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, num_prefills) \
|
||||
* max_context_chunk
|
||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||
chunk_starts + max_context_chunk)
|
||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||
|
||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||
num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(chunk_seq_lens,
|
||||
dim=1,
|
||||
out=cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
|
||||
chunked_context_metadata_cls = \
|
||||
FlashMLAPrefillMetadata.ChunkedContextMetadata
|
||||
|
||||
chunked_context_metadata = \
|
||||
chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
workspace=getattr(self, "chunked_prefill_workspace", None),
|
||||
)
|
||||
|
||||
if not self.is_deepseek_v32:
|
||||
assert max(chunked_context_metadata.max_seq_lens) <= \
|
||||
self.chunked_prefill_workspace_size
|
||||
|
||||
prefill_metadata = self.prefill_metadata_cls(
|
||||
block_table=block_table_tensor[reqs_start:, ...],
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
max_query_len=max_query_len,
|
||||
chunked_context=chunked_context_metadata,
|
||||
# for mlu
|
||||
num_prefills=num_prefills,
|
||||
max_seq_len=common_attn_metadata.seq_lens_cpu[reqs_start:].max().item(),
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
decode_metadata = self._build_decode(
|
||||
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
||||
seq_lens=seq_lens[:num_decodes],
|
||||
query_start_loc=query_start_loc[:num_decodes + 1],
|
||||
max_query_len=query_seq_lens_cpu[:num_decodes].max().item(),
|
||||
max_seq_len=common_attn_metadata.seq_lens_cpu[:num_decodes].max().item(),
|
||||
)
|
||||
|
||||
attn_metadata = self.metadata_cls(
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=num_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
# MLACommonMetadata Chunk prefill specific
|
||||
num_decodes=num_decodes,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: MLUCommonAttentionMetadata) -> bool:
|
||||
return common_attn_metadata.max_query_len == self.decoder_query_len
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class FlashMLAImpl(MLUFlashAttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
kwargs: Optional[dict[str, Any]] = {},
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for FlashAttentionImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
out_lse = None
|
||||
|
||||
# use default common metadata if kwargs does not have common_metadata
|
||||
common_metadata: MLUCommonAttentionMetadata = kwargs.get("common_metadata", None)
|
||||
if common_metadata is None:
|
||||
common_metadata = get_common_metadata()
|
||||
|
||||
only_prefill = kwargs.get("only_prefill", False)
|
||||
only_decode = kwargs.get("only_decode", False)
|
||||
attn_bias = kwargs.get("attn_bias", None)
|
||||
|
||||
assert only_prefill != only_decode, "only_prefill and only_decode cannot be True and False at the same time."
|
||||
|
||||
if only_prefill:
|
||||
cu_seqlens_q = attn_metadata.prefill.query_start_loc
|
||||
cu_seqlens_kv = common_metadata.query_start_loc
|
||||
seqused_k = common_metadata.seq_lens[attn_metadata.num_decodes:]
|
||||
max_seqlen_q = attn_metadata.prefill.max_query_len
|
||||
max_seqlen_k = attn_metadata.prefill.max_seq_len
|
||||
block_table = attn_metadata.prefill.block_table
|
||||
num_actual_tokens = attn_metadata.num_prefill_tokens
|
||||
else:
|
||||
cu_seqlens_q = None # nouse
|
||||
cu_seqlens_kv = None # nouse
|
||||
seqused_k = common_metadata.seq_lens[:attn_metadata.num_decodes]
|
||||
max_seqlen_q = None # nouse
|
||||
max_seqlen_k = common_metadata.max_seq_len
|
||||
block_table = attn_metadata.decode.block_table
|
||||
num_actual_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
skip_process_cache = ((self.use_mla
|
||||
and (common_metadata.is_prefill_only
|
||||
or self.use_fused_mla_qkv
|
||||
or only_prefill))
|
||||
or self.kv_sharing_target_layer_name is not None)
|
||||
|
||||
kv_cache_, kv_cache_scale_, kv_cache_index_ = kv_cache
|
||||
key_cache = kv_cache_[0]
|
||||
value_cache = None if self.use_mla else kv_cache_[1]
|
||||
key_cache_scale, value_cache_scale = None, None
|
||||
if kv_cache_scale_.numel() > 0:
|
||||
key_cache_scale = kv_cache_scale_[0]
|
||||
value_cache_scale = None if self.use_mla else kv_cache_scale_[1]
|
||||
if not skip_process_cache:
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
mlu_ops.quant_to_paged_cache(
|
||||
k=key[:num_actual_tokens],
|
||||
v=(None if self.use_mla else value[:num_actual_tokens]),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
k_cache_quant_scale=key_cache_scale,
|
||||
v_cache_quant_scale=value_cache_scale,
|
||||
slot_mapping=attn_metadata.slot_mapping.flatten(),
|
||||
)
|
||||
else:
|
||||
mlu_ops.reshape_paged_cache(
|
||||
k=key[:num_actual_tokens],
|
||||
v=(None if self.use_mla else value[:num_actual_tokens]),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
slot_mapping=attn_metadata.slot_mapping.flatten()
|
||||
)
|
||||
|
||||
alibi_slopes = None if self.alibi_slopes is None else \
|
||||
self.alibi_slopes.repeat(seqused_k.shape[0], 1)
|
||||
|
||||
if kwargs.get("model_type", "") == "deepseek_v32":
|
||||
from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context
|
||||
sp_context = get_sp_forward_context()
|
||||
if sp_context is not None and sp_context.is_v32:
|
||||
num_actual_tokens = sp_context.sp_attn_metadata.num_prefill_tokens
|
||||
decode_query = query[:num_actual_tokens].view(-1, self.num_heads, self.head_size)
|
||||
head_size_v = value.shape[-1] if self.use_mla else self.head_size
|
||||
decode_output = output[:num_actual_tokens].view(-1, self.num_heads, head_size_v)
|
||||
decode_query = query.unsqueeze(1) # see tokens as batch dim
|
||||
decode_output = decode_output.unsqueeze(1)
|
||||
q_quant_scale = kwargs.get("q_quant_scale", None)
|
||||
if q_quant_scale is not None:
|
||||
q_quant_scale = q_quant_scale[:num_actual_tokens].view(-1, self.num_heads)
|
||||
q_quant_scale = q_quant_scale.unsqueeze(1)
|
||||
mlu_ops.single_query_cached_kv_attn(
|
||||
q=decode_query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
out=decode_output,
|
||||
block_tables=kwargs.get("new_block_tables", None),
|
||||
context_lens=kwargs.get("new_context_lens", None),
|
||||
k_cache_quant_scale=key_cache_scale,
|
||||
v_cache_quant_scale=value_cache_scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
max_contxt_len=kwargs.get("index_topk", None),
|
||||
windows_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]),
|
||||
windows_size_right=(-1 if self.sliding_window is None else self.sliding_window[0]),
|
||||
softmax_scale=self.scale,
|
||||
head_size_v=(-1 if not self.use_mla else head_size_v),
|
||||
compute_dtype=compute_dtype,
|
||||
q_quant_scale=q_quant_scale,
|
||||
decoder_attn_dtype=self.decoder_attn_dtype,
|
||||
mask=attn_bias,
|
||||
)
|
||||
return output
|
||||
|
||||
if common_metadata.is_prefill_only or only_prefill:
|
||||
# prefill only
|
||||
prefill_causal = kwargs.get("prefill_causal", True)
|
||||
cu_seqlens_q = kwargs.get("cu_seq_lens_q", cu_seqlens_q)
|
||||
cu_seqlens_kv = kwargs.get("cu_seq_lens_kv", cu_seqlens_kv)
|
||||
max_seqlen_q = kwargs.get("max_seq_len_q", max_seqlen_q)
|
||||
max_seqlen_k = kwargs.get("max_seq_len_kv", max_seqlen_k)
|
||||
return_lse = kwargs.get("return_lse", False)
|
||||
num_prefill_query_tokens = common_metadata.num_prefill_query_tokens
|
||||
num_prefill_kv_tokens = common_metadata.num_prefill_kv_tokens
|
||||
use_f32 = attn_bias is not None and attn_bias.dtype == torch.float32
|
||||
if use_f32:
|
||||
f32_output = torch.empty_like(output, dtype=torch.float32)
|
||||
attn_output_list = mlu_ops.flash_attention(
|
||||
q=query[:num_prefill_query_tokens].to(torch.float32) if use_f32 else query[:num_prefill_query_tokens],
|
||||
k=key[:num_prefill_kv_tokens].to(torch.float32) if use_f32 else key[:num_prefill_kv_tokens],
|
||||
v=value[:num_prefill_kv_tokens].to(torch.float32) if use_f32 else value[:num_prefill_kv_tokens],
|
||||
out=f32_output[:num_prefill_query_tokens] if use_f32 else output[:num_prefill_query_tokens],
|
||||
cu_seq_lens_q=cu_seqlens_q,
|
||||
cu_seq_lens_kv=cu_seqlens_kv,
|
||||
alibi_slope=alibi_slopes,
|
||||
attn_bias=attn_bias,
|
||||
max_seq_len_q=max_seqlen_q,
|
||||
max_seq_len_kv=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
is_causal=prefill_causal,
|
||||
window_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]),
|
||||
window_size_right=(-1 if self.sliding_window is None else self.sliding_window[1]),
|
||||
compute_dtype=self.prefill_compute_dtype,
|
||||
return_lse=return_lse,
|
||||
q_quant_dtype=self.prefill_q_dtype,
|
||||
k_quant_dtype=self.prefill_k_dtype,
|
||||
v_quant_dtype=self.prefill_v_dtype
|
||||
)
|
||||
if use_f32:
|
||||
output[:num_prefill_query_tokens].copy_(f32_output[:num_prefill_query_tokens])
|
||||
|
||||
if return_lse:
|
||||
out_lse = attn_output_list[1]
|
||||
else:
|
||||
batch_size = block_table.shape[0]
|
||||
# decode only
|
||||
decode_query = query[:num_actual_tokens].view(batch_size, -1, self.num_heads, self.head_size)
|
||||
head_size_v = value.shape[-1] if self.use_mla else self.head_size
|
||||
decode_output = output[:num_actual_tokens].view(batch_size, -1, self.num_heads, head_size_v)
|
||||
q_quant_scale = kwargs.get("q_quant_scale", None)
|
||||
if q_quant_scale is not None:
|
||||
q_quant_scale = q_quant_scale[:num_actual_tokens].view(batch_size, -1, self.num_heads)
|
||||
mlu_ops.single_query_cached_kv_attn(
|
||||
q=decode_query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
out=decode_output,
|
||||
block_tables=block_table,
|
||||
context_lens=seqused_k,
|
||||
k_cache_quant_scale=key_cache_scale,
|
||||
v_cache_quant_scale=value_cache_scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
max_contxt_len=max_seqlen_k,
|
||||
windows_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]),
|
||||
windows_size_right=(-1 if self.sliding_window is None else self.sliding_window[0]),
|
||||
softmax_scale=self.scale,
|
||||
head_size_v=(-1 if not self.use_mla else head_size_v),
|
||||
compute_dtype=attn_metadata.decode.compute_dtype,
|
||||
q_quant_scale=q_quant_scale,
|
||||
decoder_attn_dtype=self.decoder_attn_dtype,
|
||||
mask=attn_bias,
|
||||
)
|
||||
|
||||
return output if out_lse is None else (output, out_lse)
|
||||
295
vllm_mlu/v1/attention/backends/utils.py
Normal file
295
vllm_mlu/v1/attention/backends/utils.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
|
||||
|
||||
COMMON_METADATA_STR: str = "common_metadata"
|
||||
|
||||
|
||||
class MLUInferMode(Enum):
|
||||
CHUNKED = 1
|
||||
PREFILL_ONLY = 2
|
||||
DECODE_ONLY = 3
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
max_query_len,
|
||||
max_computed_tokens,
|
||||
uniform_decode_query_len: int = 1,
|
||||
) -> Enum:
|
||||
if max_query_len <= uniform_decode_query_len:
|
||||
return MLUInferMode.DECODE_ONLY
|
||||
elif max_computed_tokens == 0:
|
||||
return MLUInferMode.PREFILL_ONLY
|
||||
else:
|
||||
return MLUInferMode.CHUNKED
|
||||
|
||||
@property
|
||||
def is_prefill_only(self):
|
||||
return self == MLUInferMode.PREFILL_ONLY
|
||||
|
||||
@property
|
||||
def is_decode_only(self):
|
||||
return self == MLUInferMode.DECODE_ONLY
|
||||
|
||||
@property
|
||||
def is_chunked(self):
|
||||
return self == MLUInferMode.CHUNKED
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLUCommonAttentionMetadata(CommonAttentionMetadata):
|
||||
"""
|
||||
Attention metadata attributes that can be shared by layers in different KV
|
||||
cache groups and thus having different block table.
|
||||
"""
|
||||
seq_start_loc: torch.Tensor | None = None
|
||||
seq_start_loc_cpu: torch.Tensor | None = None
|
||||
"""(batch_size + 1,), the start location of each request in the input key/value sequence."""
|
||||
num_input_tokens: int = 0
|
||||
"""Number of query tokens with padding."""
|
||||
num_prefill_query_tokens: int = 0
|
||||
"""Number of query tokens in prefill phase."""
|
||||
num_prefill_kv_tokens: int = 0
|
||||
"""Number of key/value tokens in prefill phase."""
|
||||
infer_mode: MLUInferMode | None = None
|
||||
"""Inference mode for flash attention."""
|
||||
|
||||
@property
|
||||
def is_prefill_only(self):
|
||||
return self.infer_mode == MLUInferMode.PREFILL_ONLY
|
||||
|
||||
@property
|
||||
def is_decode_only(self):
|
||||
return self.infer_mode == MLUInferMode.DECODE_ONLY
|
||||
|
||||
@property
|
||||
def is_chunked(self):
|
||||
return self.infer_mode == MLUInferMode.CHUNKED
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
query_start_loc, query_start_loc_cpu,
|
||||
seq_lens, seq_lens_cpu,
|
||||
num_computed_tokens_cpu,
|
||||
num_reqs, num_actual_tokens, max_query_len,
|
||||
block_table_tensor, slot_mapping,
|
||||
seq_start_loc, is_start_loc_match,
|
||||
num_input_tokens: int = 0,
|
||||
num_speculative_tokens: int = 0,
|
||||
has_prefill_reqs: bool = False
|
||||
):
|
||||
"""Build attention metadata for MLU inference.
|
||||
|
||||
Args:
|
||||
has_prefill_reqs: Whether there are pending prefill requests with chunked.
|
||||
"""
|
||||
infer_mode = None
|
||||
if is_start_loc_match:
|
||||
infer_mode = MLUInferMode.PREFILL_ONLY
|
||||
elif max_query_len <= (1 + num_speculative_tokens) and (not has_prefill_reqs):
|
||||
infer_mode = MLUInferMode.DECODE_ONLY
|
||||
else:
|
||||
infer_mode = MLUInferMode.CHUNKED
|
||||
num_input_tokens = (
|
||||
num_actual_tokens if num_input_tokens == 0
|
||||
else num_input_tokens
|
||||
)
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
return cls(query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_start_loc=seq_start_loc,
|
||||
seq_start_loc_cpu=seq_start_loc.to("cpu", non_blocking=True),
|
||||
num_input_tokens=num_input_tokens,
|
||||
infer_mode=infer_mode,
|
||||
num_prefill_query_tokens=num_actual_tokens,
|
||||
num_prefill_kv_tokens=num_actual_tokens)
|
||||
|
||||
def save(self, infer_phase: str):
|
||||
csv_path = os.getenv("VLLM_STEP_INPUT_CSV_PATH", None)
|
||||
if not csv_path:
|
||||
return
|
||||
|
||||
header = [
|
||||
"infer_phase", "infer_mode", "num_reqs", "num_actual_tokens",
|
||||
"max_query_len", "max_seq_len", "query_start_loc", "seq_lens"
|
||||
]
|
||||
data = [
|
||||
infer_phase, self.infer_mode, self.num_reqs,
|
||||
self.num_actual_tokens, self.max_query_len, self.max_seq_len,
|
||||
str(self.query_start_loc_cpu.tolist()),
|
||||
str(self.seq_lens_cpu.tolist())
|
||||
]
|
||||
data_dict = dict(zip(header, data))
|
||||
df_csv = pd.DataFrame(data_dict, index=[0])
|
||||
|
||||
if infer_phase == "RealInfer":
|
||||
print(df_csv.to_string())
|
||||
|
||||
try:
|
||||
if dir_path := os.path.dirname(csv_path):
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
append = False
|
||||
if os.path.isfile(csv_path):
|
||||
try:
|
||||
df_old = pd.read_csv(csv_path)
|
||||
append = (df_old.columns.tolist() == header)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Existing {csv_path} failed to be read and will be overwritten")
|
||||
if append:
|
||||
df_csv.to_csv(csv_path, mode='a', header=False, index=False)
|
||||
else:
|
||||
df_csv.to_csv(csv_path, index=False)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Invalid VLLM_STEP_INPUT_CSV_PATH: {csv_path} to dump step inputs, Error: {e}")
|
||||
|
||||
|
||||
def get_common_metadata_from_attn_metadata(
|
||||
attn_metadata) -> Union[MLUCommonAttentionMetadata, None]:
|
||||
"""
|
||||
Get MLUCommonAttentionMetadata for MLU-V1 inference.
|
||||
Use outside of set_forward_context().
|
||||
"""
|
||||
if attn_metadata is None:
|
||||
return
|
||||
|
||||
assert (isinstance(attn_metadata, dict)
|
||||
and COMMON_METADATA_STR in attn_metadata), \
|
||||
f"MLU-V1 only support type(attn_metadata)=dict, and " + \
|
||||
f"{COMMON_METADATA_STR} in attn_metadata. Now, type(attn_metadata)=" + \
|
||||
f"{type(attn_metadata)}, or {COMMON_METADATA_STR} not in attn_metadata."
|
||||
return attn_metadata[COMMON_METADATA_STR]
|
||||
|
||||
|
||||
def get_common_metadata() -> Union[MLUCommonAttentionMetadata, None]:
|
||||
"""
|
||||
Get MLUCommonAttentionMetadata for MLU-V1 inference.
|
||||
Use inside of set_forward_context().
|
||||
"""
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
return get_common_metadata_from_attn_metadata(attn_metadata)
|
||||
|
||||
|
||||
def unpad_common_attn_metadata(
|
||||
common_metadata: MLUCommonAttentionMetadata,
|
||||
num_reqs: int,
|
||||
num_scheduled_tokens: int,
|
||||
):
|
||||
"""
|
||||
Unpad MLUCommonAttentionMetadata by given num_reqs and num_scheduled_tokens.
|
||||
"""
|
||||
common_metadata.num_reqs = num_reqs
|
||||
common_metadata.num_input_tokens = num_scheduled_tokens
|
||||
common_metadata.query_start_loc = common_metadata.query_start_loc[:num_reqs + 1]
|
||||
common_metadata.query_start_loc_cpu = common_metadata.query_start_loc_cpu[:num_reqs + 1]
|
||||
common_metadata.seq_start_loc = common_metadata.seq_start_loc[:num_reqs + 1]
|
||||
common_metadata.seq_lens = common_metadata.seq_lens[:num_reqs]
|
||||
common_metadata.seq_lens_cpu = common_metadata.seq_lens_cpu[:num_reqs]
|
||||
common_metadata.block_table_tensor = common_metadata.block_table_tensor[:num_reqs]
|
||||
|
||||
def reorder_batch_to_split_decodes_and_prefills(
|
||||
input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput",
|
||||
decode_threshold: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Reorders the batch to split into prefill and decode requests; places all
|
||||
requests with <= decode_threshold tokens at the front of the batch.
|
||||
|
||||
Returns:
|
||||
True if the batch was modified, False otherwise.
|
||||
"""
|
||||
# We now want to reorder the batch into decode → extend → prefill order
|
||||
# where:
|
||||
# decode: request with num_scheduled_tokens <= decode_threshold
|
||||
# extend: non-decode request with existing context
|
||||
# prefill: non-decode request with no existing context
|
||||
# NOTE for now we loosely use "decode" to mean requests where attention is
|
||||
# likely memory-bound and "prefill" to mean requests where attention is
|
||||
# likely compute-bound,
|
||||
num_reqs = len(input_batch.req_ids)
|
||||
num_scheduled_tokens = [
|
||||
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
|
||||
]
|
||||
num_scheduled_tokens_np = np.array(num_scheduled_tokens)
|
||||
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: enhence decode mode condition that all prompt tokens are computed.
|
||||
'''
|
||||
# is_decode = num_scheduled_tokens_np <= decode_threshold
|
||||
is_decode = (
|
||||
(num_scheduled_tokens_np <= decode_threshold)
|
||||
& (num_computed_tokens_np >= input_batch.num_prompt_tokens[:num_reqs])
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
is_extend = (~is_decode) & (num_computed_tokens_np > 0)
|
||||
is_prefill = (~is_decode) & (num_computed_tokens_np == 0)
|
||||
|
||||
# Desired order: decode → extend → prefill
|
||||
req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default
|
||||
req_regions[is_extend] = 1
|
||||
req_regions[is_prefill] = 2
|
||||
|
||||
num_decodes = int(is_decode.sum())
|
||||
num_extends = int(is_extend.sum())
|
||||
|
||||
target_regions = np.zeros(num_reqs, dtype=np.int32)
|
||||
target_regions[num_decodes : num_decodes + num_extends] = 1
|
||||
target_regions[num_decodes + num_extends :] = 2
|
||||
|
||||
needs_swap = req_regions != target_regions
|
||||
|
||||
if not needs_swap.any():
|
||||
return False
|
||||
|
||||
# Extract indices that need swapping and sort by target region
|
||||
orig_indices = np.where(needs_swap)[0]
|
||||
sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
|
||||
src_indices = orig_indices[sorted_order]
|
||||
|
||||
src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)}
|
||||
|
||||
for src in src_dest_map:
|
||||
dst = src_dest_map[src]
|
||||
while src != dst:
|
||||
input_batch.swap_states(src, dst)
|
||||
# Mark dst as done by updating its destination to itself
|
||||
next_dst = src_dest_map.get(dst, dst)
|
||||
src_dest_map[dst] = dst
|
||||
dst = next_dst
|
||||
|
||||
return True
|
||||
Reference in New Issue
Block a user