[Model] Support DeepSeek-V4

3  vllm_mlu/v1/__init__.py  Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

3  vllm_mlu/v1/attention/__init__.py  Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

3  vllm_mlu/v1/attention/backends/__init__.py  Normal file
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

1050  vllm_mlu/v1/attention/backends/flash_attn.py  Normal file
(File diff suppressed because it is too large.)
404  vllm_mlu/v1/attention/backends/gdn_attn.py  Normal file
@@ -0,0 +1,404 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from collections import OrderedDict, deque

from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.attention.backends.gdn_attn import (GDNAttentionMetadataBuilder,
                                                 GDNAttentionMetadata,
                                                 )
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                              compute_causal_conv1d_metadata,
                                              split_decodes_and_prefills,)


class DeviceAwareLocalIdMapper:

    def __init__(self, batch_size: int):
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        self.batch_size = batch_size
        self.global_to_local: OrderedDict[int, int] = OrderedDict()
        self.local_to_global = {}
        self.available_local_ids = deque(range(batch_size))

    def batch_get_local_ids(self, global_id_tensor: torch.Tensor) -> torch.Tensor:
        original_device = global_id_tensor.device
        original_shape = global_id_tensor.shape

        flat_global_cpu = global_id_tensor.cpu().numpy().ravel()
        num_elements = flat_global_cpu.size
        local_ids_cpu = torch.empty(num_elements, dtype=global_id_tensor.dtype)

        g2l = self.global_to_local
        unique_miss_set = set()

        # Pass 1: handle hits and collect unique misses
        for i, gid in enumerate(flat_global_cpu):
            if gid in g2l:
                local_id = g2l[gid]
                local_ids_cpu[i] = local_id
                g2l.move_to_end(gid)
            else:
                local_ids_cpu[i] = -1
                unique_miss_set.add(gid)

        # Pass 2: assign local IDs to unique new global IDs
        new_mappings = {}
        available = self.available_local_ids
        local_to_global = self.local_to_global

        for gid in unique_miss_set:
            if len(g2l) >= self.batch_size:
                old_gid, old_local = g2l.popitem(last=False)
                available.append(old_local)
                local_to_global.pop(old_local, None)
            new_local = available.popleft()
            g2l[gid] = new_local
            local_to_global[new_local] = gid
            new_mappings[gid] = new_local

        # Pass 3: fill in all miss positions
        for i, gid in enumerate(flat_global_cpu):
            if local_ids_cpu[i].item() == -1:
                local_ids_cpu[i] = new_mappings[gid]

        return local_ids_cpu.to(original_device).view(original_shape)

    def reset(self):
        self.global_to_local.clear()
        self.local_to_global.clear()
        self.available_local_ids = deque(range(self.batch_size))

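Not part of the diff: a minimal sketch of how the mapper above behaves, assuming a toy batch_size of 2. Global block ids get packed into the bounded local range [0, batch_size), repeated ids reuse their slot, and the least-recently-used mapping is evicted once the range is full.

```python
import torch

mapper = DeviceAwareLocalIdMapper(batch_size=2)

# The first two distinct global ids receive local ids 0 and 1 (the order
# within a single call is not guaranteed, since misses are collected in a set).
a = mapper.batch_get_local_ids(torch.tensor([10, 20]))

# A repeated id reuses its existing local slot and is marked most recently used;
# a third id then evicts the least-recently-used mapping (here, id 20's slot).
b = mapper.batch_get_local_ids(torch.tensor([10]))   # same local id as before
c = mapper.batch_get_local_ids(torch.tensor([30]))   # takes over id 20's slot

mapper.reset()                                        # all local slots free again
```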
def vllm__v1__attention__bachends__GDNAttentionMetadataBuilder____init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    assert isinstance(kv_cache_spec, MambaSpec)
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config
    self.speculative_config = vllm_config.speculative_config
    self.kv_cache_spec = kv_cache_spec
    if self.speculative_config:
        self.num_spec = self.speculative_config.num_speculative_tokens
    else:
        self.num_spec = 0
    self.use_spec_decode = self.num_spec > 0
    self._init_reorder_batch_threshold(1, self.use_spec_decode)

    self.use_full_cuda_graph = (
        self.compilation_config.cudagraph_mode.has_full_cudagraphs()
    )
    self.decode_cudagraph_max_bs = min(
        self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
        self.compilation_config.max_cudagraph_capture_size,
    )

    self.spec_state_indices_tensor = torch.empty(
        (self.decode_cudagraph_max_bs, self.num_spec + 1),
        dtype=torch.int32,
        device=device,
    )
    self.non_spec_state_indices_tensor = torch.empty(
        (self.decode_cudagraph_max_bs,),
        dtype=torch.int32,
        device=device,
    )
    self.spec_sequence_masks = torch.empty(
        (self.decode_cudagraph_max_bs,),
        dtype=torch.bool,
        device=device,
    )
    self.spec_token_indx = torch.empty(
        (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
        dtype=torch.int32,
        device=device,
    )
    self.non_spec_token_indx = torch.empty(
        (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
        dtype=torch.int32,
        device=device,
    )
    self.spec_query_start_loc = torch.empty(
        (self.decode_cudagraph_max_bs + 1,),
        dtype=torch.int32,
        device=device,
    )
    self.non_spec_query_start_loc = torch.empty(
        (self.decode_cudagraph_max_bs + 1,),
        dtype=torch.int32,
        device=device,
    )
    self.num_accepted_tokens = torch.empty(
        (self.decode_cudagraph_max_bs,),
        dtype=torch.int32,
        device=device,
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: support qwen3-next
    '''
    self.mapper = DeviceAwareLocalIdMapper(
        self.vllm_config.mlu_config.mamba_support_max_batch_size)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''


def vllm__v1__attention__bachends__GDNAttentionMetadataBuilder__build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    num_accepted_tokens: torch.Tensor | None = None,
    num_decode_draft_tokens_cpu: torch.Tensor | None = None,
    fast_build: bool = False,
) -> GDNAttentionMetadata:
    m = common_attn_metadata

    query_start_loc = m.query_start_loc
    context_lens = m.num_computed_tokens_cpu
    context_lens_tensor = context_lens.to(query_start_loc.device)
    nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

    if (
        not self.use_spec_decode
        or num_decode_draft_tokens_cpu is None
        or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
        .sum()
        .item()
        == 0
    ):
        spec_sequence_masks = None
        num_spec_decodes = 0
    else:
        spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
        num_spec_decodes = spec_sequence_masks.sum().item()
        if num_spec_decodes == 0:
            spec_sequence_masks = None
        else:
            spec_sequence_masks = spec_sequence_masks.to(
                query_start_loc.device, non_blocking=True
            )

    if spec_sequence_masks is None:
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(m, decode_threshold=1)
        )
        num_spec_decode_tokens = 0
        spec_token_indx = None
        non_spec_token_indx = None
        spec_state_indices_tensor = None
        non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
        spec_query_start_loc = None
        non_spec_query_start_loc = query_start_loc
        num_accepted_tokens = None
    else:
        query_lens = query_start_loc[1:] - query_start_loc[:-1]

        non_spec_query_lens = query_lens[~spec_sequence_masks]
        num_decodes = (non_spec_query_lens == 1).sum().item()
        num_prefills = non_spec_query_lens.size(0) - num_decodes
        num_decode_tokens = num_decodes
        num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
        num_spec_decode_tokens = (
            query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
        )

        if num_prefills == 0 and num_decodes == 0:
            spec_token_size = min(
                num_spec_decodes * (self.num_spec + 1),
                query_start_loc[-1].item(),
            )
            spec_token_indx = torch.arange(
                spec_token_size,
                dtype=torch.int32,
                device=query_start_loc.device,
            )
            non_spec_token_indx = torch.empty(
                0, dtype=torch.int32, device=query_start_loc.device
            )
            spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
            non_spec_state_indices_tensor = None
            spec_query_start_loc = query_start_loc
            non_spec_query_start_loc = None
        else:
            spec_token_masks = torch.repeat_interleave(
                spec_sequence_masks, query_lens
            )
            index = torch.argsort(spec_token_masks)
            num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
            non_spec_token_indx = index[:num_non_spec_tokens]
            spec_token_indx = index[num_non_spec_tokens:]

            spec_state_indices_tensor = m.block_table_tensor[
                spec_sequence_masks, : self.num_spec + 1
            ]
            non_spec_state_indices_tensor = m.block_table_tensor[
                ~spec_sequence_masks, 0
            ]

            spec_query_start_loc = torch.zeros(
                num_spec_decodes + 1,
                dtype=torch.int32,
                device=query_start_loc.device,
            )
            torch.cumsum(
                query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
            )
            non_spec_query_start_loc = torch.zeros(
                query_lens.size(0) - num_spec_decodes + 1,
                dtype=torch.int32,
                device=query_start_loc.device,
            )
            torch.cumsum(
                query_lens[~spec_sequence_masks],
                dim=0,
                out=non_spec_query_start_loc[1:],
            )

        assert num_accepted_tokens is not None
        num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

    if num_prefills > 0:
        has_initial_state = context_lens_tensor > 0
        if spec_sequence_masks is not None:
            has_initial_state = has_initial_state[~spec_sequence_masks]
        nums_dict, batch_ptr, token_chunk_offset_ptr = (
            compute_causal_conv1d_metadata(non_spec_query_start_loc)
        )
    else:
        has_initial_state = None
    num_actual_tokens = (
        num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
    )

    # Prepare tensors for cudagraph.
    #
    # With speculative decoding, the xgrammar backend may roll back tokens,
    # causing some sequences to have fewer draft tokens than self.num_spec.
    #
    # In those cases, the max possible batch size for n tokens is
    # min(n, cudagraph_max_bs).
    if (
        self.use_full_cuda_graph
        and num_prefills == 0
        and num_decodes == 0
        and num_spec_decodes <= self.decode_cudagraph_max_bs
        and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
    ):
        num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
        batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

        self.spec_state_indices_tensor[:num_spec_decodes].copy_(
            spec_state_indices_tensor, non_blocking=True
        )
        spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
        spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

        self.spec_sequence_masks[:num_spec_decodes].copy_(
            spec_sequence_masks, non_blocking=True
        )
        spec_sequence_masks = self.spec_sequence_masks[:batch_size]
        spec_sequence_masks[num_spec_decodes:].fill_(False)

        assert non_spec_token_indx is not None and spec_token_indx is not None
        self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
            non_spec_token_indx, non_blocking=True
        )
        non_spec_token_indx = self.non_spec_token_indx[
            : non_spec_token_indx.size(0)
        ]

        self.spec_token_indx[: spec_token_indx.size(0)].copy_(
            spec_token_indx, non_blocking=True
        )
        spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]

        self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
            spec_query_start_loc, non_blocking=True
        )
        spec_num_query_tokens = spec_query_start_loc[-1]  # type: ignore[index]
        spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
        spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)

        self.num_accepted_tokens[:num_spec_decodes].copy_(
            num_accepted_tokens, non_blocking=True
        )
        num_accepted_tokens = self.num_accepted_tokens[:batch_size]
        num_accepted_tokens[num_spec_decodes:].fill_(1)

    if (
        self.use_full_cuda_graph
        and num_prefills == 0
        and num_spec_decodes == 0
        and num_decodes <= self.decode_cudagraph_max_bs
    ):
        num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
        batch_size = num_actual_tokens

        self.non_spec_state_indices_tensor[:num_decodes].copy_(
            non_spec_state_indices_tensor, non_blocking=True
        )
        non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
            :batch_size
        ]
        non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

        self.non_spec_query_start_loc[: num_decodes + 1].copy_(
            non_spec_query_start_loc, non_blocking=True
        )
        non_spec_num_query_tokens = non_spec_query_start_loc[-1]  # type: ignore[index]
        non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
        non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: support qwen3-next
    '''
    non_spec_state_indices_tensor = self.mapper.batch_get_local_ids(
        non_spec_state_indices_tensor)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    attn_metadata = GDNAttentionMetadata(
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_spec_decodes=num_spec_decodes,
        num_spec_decode_tokens=num_spec_decode_tokens,
        num_actual_tokens=num_actual_tokens,
        has_initial_state=has_initial_state,
        spec_query_start_loc=spec_query_start_loc,
        non_spec_query_start_loc=non_spec_query_start_loc,
        spec_state_indices_tensor=spec_state_indices_tensor,
        non_spec_state_indices_tensor=non_spec_state_indices_tensor,
        spec_sequence_masks=spec_sequence_masks,
        spec_token_indx=spec_token_indx,
        non_spec_token_indx=non_spec_token_indx,
        num_accepted_tokens=num_accepted_tokens,
        nums_dict=nums_dict,
        batch_ptr=batch_ptr,
        token_chunk_offset_ptr=token_chunk_offset_ptr,
    )
    return attn_metadata


MluHijackObject.apply_hijack(GDNAttentionMetadataBuilder,
                             GDNAttentionMetadataBuilder.__init__,
                             vllm__v1__attention__bachends__GDNAttentionMetadataBuilder____init__)

MluHijackObject.apply_hijack(GDNAttentionMetadataBuilder,
                             GDNAttentionMetadataBuilder.build,
                             vllm__v1__attention__bachends__GDNAttentionMetadataBuilder__build)
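Not part of the diff: MluHijackObject.apply_hijack is vLLM-MLU's patching utility; its implementation is not shown in this commit, but the two calls above conceptually amount to a monkeypatch of the upstream builder methods. A rough, illustrative equivalent (not the actual vllm_mlu code):

```python
# Illustrative only -- not the real MluHijackObject. The real helper may also
# record the original method so the patch can be reverted; this sketch only
# shows the basic attribute swap that makes later calls hit the MLU version.
def apply_hijack_sketch(cls, original, replacement):
    name = original.__name__          # e.g. "__init__" or "build"
    setattr(cls, name, replacement)   # subsequent lookups resolve to `replacement`
```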
934  vllm_mlu/v1/attention/backends/mla/flashmla.py  Normal file
@@ -0,0 +1,934 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional

import torch

from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv, round_down
from vllm.attention.backends.utils import MLADims
from vllm.config import ModelConfig
from vllm.v1.attention.backends.mla.common import (
    MLACommonBackend, MLACommonPrefillMetadata,
    MLACommonDecodeMetadata, MLACommonMetadata,
    MLACommonMetadataBuilder, M, QueryLenSupport,
    use_cudnn_prefill, use_flashinfer_prefill,
    use_trtllm_ragged_deepseek_prefill,
    FlashInferPrefillMetadata,
    CudnnPrefillMetadata,
    MLACommonImpl,
    CUDNN_WORKSPACE_SIZE
)
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport, split_decodes_and_prefills,
    infer_global_hyperparameters, get_per_layer_parameters,
)
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionLayer,
    MLAAttentionImpl,
)
from vllm.v1.kv_cache_interface import AttentionSpec

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

import vllm_mlu._mlu_utils as mlu_envs
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.flash_attn import MLUFlashAttentionImpl
from vllm_mlu.v1.attention.backends.utils import (
    MLUCommonAttentionMetadata, get_common_metadata,
    MLUInferMode)
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.platforms import current_platform
from vllm import envs

try:
    from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
    from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache  # noqa: F401

    flashinfer_available = True
except ImportError:
    BatchPrefillWithRaggedKVCacheWrapper = object

    flashinfer_available = False

logger = init_logger(__name__)


from vllm_mlu.mlu_hijack_utils import MluHijackObject


class MLACommonBackend_MluHijack(MLACommonBackend):

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576, 512]


def get_mla_dims(model_config: ModelConfig) -> MLADims:
    hf_text_config = model_config.hf_text_config

    if model_config.hf_text_config.model_type == "deepseek_v4":
        return MLADims(
            q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
            kv_lora_rank=hf_text_config.head_dim,
            qk_nope_head_dim=hf_text_config.head_dim - hf_text_config.rope_head_dim,
            qk_rope_head_dim=hf_text_config.rope_head_dim,
            v_head_dim=hf_text_config.head_dim,
        )

    return MLADims(
        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
        kv_lora_rank=hf_text_config.kv_lora_rank,
        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
        v_head_dim=hf_text_config.v_head_dim,
    )

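Not part of the diff: for deepseek_v4, the MLA dims above are derived from head_dim and rope_head_dim rather than from explicit kv_lora_rank / *_head_dim fields. A tiny worked example with hypothetical config values (these numbers are only for illustration, not DeepSeek-V4's actual configuration):

```python
# Hypothetical values, purely illustrative.
head_dim, rope_head_dim = 512, 64

kv_lora_rank     = head_dim                   # 512
qk_nope_head_dim = head_dim - rope_head_dim   # 448
qk_rope_head_dim = rope_head_dim              # 64
v_head_dim       = head_dim                   # 512

# If the KV cache stores kv_lora_rank + qk_rope_head_dim per token, this would
# give 512 + 64 = 576, lining up with the head sizes advertised by
# get_supported_head_sizes() above.
```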
class MLACommonMetadataBuilder_MluHijack(MLACommonMetadataBuilder):

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
        metadata_cls: type[M] | None = None,
        supports_dcp_with_varlen: bool = False,
    ):
        self.metadata_cls = (
            metadata_cls if metadata_cls is not None else MLACommonMetadata
        )
        self.kv_cache_spec = kv_cache_spec
        scheduler_config = vllm_config.scheduler_config
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.device = device

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.aot_schedule = current_platform.is_cuda()
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
        self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size

        # Don't try to access the runner on AMD
        if self.aot_schedule:
            self.page_size = self.kv_cache_spec.block_size

        self.chunked_prefill_workspace_size = (
            self.determine_chunked_prefill_workspace_size(vllm_config)
        )

        if self.dcp_world_size > 1:
            # Note(hc): The local kvcache is incomplete when DCP is triggered,
            # an additional kvcache allgather across the DCP group is therefore
            # required, so the workspace has to be enlarged by 1/DCP relative
            # to the original TP allocation.
            assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0
            self.chunked_prefill_workspace = torch.empty(
                (
                    self.chunked_prefill_workspace_size
                    + self.chunked_prefill_workspace_size // self.dcp_world_size,
                    self.model_config.get_head_size(),
                ),
                dtype=self.model_config.dtype,
                device=device,
            )
        else:
            self.chunked_prefill_workspace = torch.empty(
                (
                    self.chunked_prefill_workspace_size,
                    self.model_config.get_head_size(),
                ),
                dtype=self.model_config.dtype,
                device=device,
            )

        self._use_cudnn_prefill = use_cudnn_prefill()
        self._use_fi_prefill = use_flashinfer_prefill()
        self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
        self.prefill_metadata_cls = (
            FlashInferPrefillMetadata
            if self._use_fi_prefill
            else CudnnPrefillMetadata
            if self._use_cudnn_prefill
            else MLACommonPrefillMetadata
        )

        if self._use_fi_prefill:
            self._workspace_buffer = torch.empty(
                envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=device,
            )

            self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
            self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []

            self._global_hyperparameters = infer_global_hyperparameters(
                get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
            )

        if self._use_trtllm_ragged_prefill:
            self._workspace_buffer = torch.empty(
                envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=device,
            )

        if self._use_cudnn_prefill:
            self.cudnn_workspace = torch.empty(
                CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
                dtype=torch.int8,
                device=device,
            )

        supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
        self._init_reorder_batch_threshold(
            self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
        )

        # Validate consistency between query_len_support and reorder_batch_threshold
        if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
            assert self.reorder_batch_threshold == 1, (
                f"reorder_batch_threshold must be 1 when query_len_support is "
                f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
            )


MluHijackObject.apply_hijack(MLACommonBackend,
                             MLACommonBackend.get_supported_head_sizes,
                             MLACommonBackend_MluHijack.get_supported_head_sizes)
MluHijackObject.apply_hijack(MLACommonMetadataBuilder,
                             MLACommonMetadataBuilder.__init__,
                             MLACommonMetadataBuilder_MluHijack.__init__)


class FlashMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        return FlashMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (1, num_blocks, num_kv_heads, block_size, head_size)

    @staticmethod
    def get_kv_cache_scale_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
    ) -> tuple[int, ...]:
        return (1, num_blocks, num_kv_heads, block_size)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576, 512]


@dataclass
class FlashMLAPrefillMetadata(MLACommonPrefillMetadata):
    num_prefills: int = -1  # for gather_cache
    max_seq_len: int = -1  # for attn forward

    @property
    def block_tables(self):
        return self.block_table

    @property
    def context_chunk_cu_seq_lens(self):
        if self.chunked_context is None:
            return None
        return self.chunked_context.cu_seq_lens

    @property
    def context_chunk_starts(self):
        if self.chunked_context is None:
            return None
        return self.chunked_context.starts

    @property
    def context_chunk_seq_tot(self):
        if self.chunked_context is None:
            return None
        return self.chunked_context.seq_tot

    @property
    def context_chunk_max_seq_lens(self):
        if self.chunked_context is None:
            return None
        return self.chunked_context.max_seq_lens

    @property
    def context_chunk_workspace(self):
        if self.chunked_context is None:
            return None
        return self.chunked_context.workspace


@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    tile_scheduler_metadata: torch.Tensor
    num_splits: torch.Tensor

    # added for MLU rope and attn forward
    query_start_loc: torch.Tensor  # for rope
    max_query_len: int  # for rope
    max_seq_len: int = -1  # for attn forward


@dataclass
class FlashMLAMetadata(MLACommonMetadata):
    num_prefill_tokens: Optional[int] = None


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
    reorder_batch_threshold: int = 128  # process small prefills with decode pathway
    # ^ TODO(matt): tune this

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(
            kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
        )

        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )

        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None
        self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")

        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: 1. set decoder_query_len for mtp
        @brief: 2. init chunk workspace for prefix_caching only
        @brief: 3. set prefill_metadata_cls
        @brief: 4. add deepseek v3.2 infos
        '''
        cache_config = vllm_config.cache_config
        scheduler_config = vllm_config.scheduler_config
        speculative_config = vllm_config.speculative_config
        self.num_speculative_tokens = (speculative_config.num_speculative_tokens
                                       if speculative_config is not None else 0)
        self.decoder_query_len = 1 + self.num_speculative_tokens

        self.max_model_len = self.model_config.max_model_len
        self.is_deepseek_v32 = self.model_config.hf_text_config.model_type == "deepseek_v32"

        self.enable_caching = cache_config.enable_prefix_caching
        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
        if (not self.is_deepseek_v32 and not self.chunked_prefill_enabled and
                (mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED and self.enable_caching)):
            self.chunked_prefill_workspace_size = min(
                # Make sure there is enough room for 8 full-length requests
                # or at least 4 pages of cache per request
                max(
                    8 * self.model_config.max_model_len, 4 *
                    scheduler_config.max_num_seqs * cache_config.block_size),
                # For long-context models try not to over-allocate limiting
                # kv-cache space, limiting it to 128k tokens,
                # which would result in the workspace being:
                #   2*(576)*(128*1024) = 144mb
                # (assuming 576 MLA head dim, and fp16)
                # which would result in up-projected context being
                #   2*(192*128)*(128*1024) = 6gb
                # (assuming 192 QK head dim, 128 heads, and fp16)
                128 * 1024)
            assert self.chunked_prefill_workspace_size >= \
                scheduler_config.max_num_seqs * cache_config.block_size
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 self.model_config.get_head_size()),
                dtype=self.model_config.dtype,
                device=device,
            )

        self.prefill_metadata_cls = FlashMLAPrefillMetadata
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # We now want to reorder the batch so that the "decode" requests are
        # at the front and the "prefill" requests are at the back, using the
        # least number of swaps possible. (NOTE for now we loosely use
        # "decode" to mean requests where attention is likely memory-bound
        # and "prefill" to mean requests where attention is likely
        # compute-bound, TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0
        # mlu v1 mtp forces decoder_query_len = 1 for k > 1, so we should set it again
        self.decoder_query_len = 1 + self.num_speculative_tokens

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # for now treat 1 scheduled token as "decode" even if it's not,
            # we should update this to something like < 8 in the future but
            # currently the TritonMLA._forward_decode only supports
            # num_tokens = 1
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: record prefill and decode requests and token nums to call
                    chunked fa and single-query attn respectively in forward.
            @Notes: decodes require that all prompt tokens are computed.
            '''
            req_index = input_batch.req_id_to_index.get(req_id)
            all_prompt_tokens_has_computed = (
                input_batch.num_computed_tokens_cpu[req_index] >=
                input_batch.num_prompt_tokens[req_index])
            if num_tokens <= self.decoder_query_len and all_prompt_tokens_has_computed:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens
            '''
            ==================
            End of MLU Hijack
            ==================
            '''

        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new requests are generally appended to the
        # persistent batch so they should already be at the back)
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back",
        # i.e. past where the last decode should be in the reordered batch,
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based on
        # the above loop
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i, we can swap it
            # with the prefill closest to the front of the batch
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break

            input_batch.swap_states(prefills[i - 1], decode_idx)
            modified_batch = True

        return modified_batch

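Not part of the diff: a toy illustration of the swap loop above, operating on a plain list instead of an InputBatch. With five requests where positions 1 and 4 are decodes and 0, 2, 3 are prefills, a single swap moves all decodes to the front:

```python
# Standalone sketch of the reorder swap logic (illustrative data only).
batch = ["P0", "D1", "P2", "P3", "D4"]   # D* = decode, P* = prefill
decodes, prefills = [1, 4], [0, 2, 3]    # positions, in ascending order
num_decodes = len(decodes)

for i in range(1, min(num_decodes, len(prefills)) + 1):
    decode_idx = decodes[num_decodes - i]
    if decode_idx < num_decodes:         # this decode already sits in the front region
        break
    p = prefills[i - 1]                  # prefill closest to the front
    batch[p], batch[decode_idx] = batch[decode_idx], batch[p]

print(batch)                             # ['D4', 'D1', 'P2', 'P3', 'P0']
```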
    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens: torch.Tensor,
        query_start_loc: torch.Tensor,
        max_query_len: int,
        max_seq_len: int,
    ) -> FlashMLADecodeMetadata:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: set tile_scheduler_metadata and num_splits to None.
        @brief: set dcp_tot_seq_lens_device.
        '''
        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            tile_scheduler_metadata=None,
            num_splits=None,
            dcp_tot_seq_lens=None,
            # for mlu
            max_seq_len=max_seq_len,
            query_start_loc=query_start_loc,
            max_query_len=max_query_len
        )
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

    def build_for_cudagraph_capture(
            self, common_attn_metadata: MLUCommonAttentionMetadata) -> M:
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with MLA.
        """
        m = common_attn_metadata
        if m.infer_mode == MLUInferMode.DECODE_ONLY:
            assert m.num_reqs * m.max_query_len == m.num_actual_tokens, \
                "MLA only supports decode-only full CUDAGraph capture. " \
                "Make sure all cudagraph capture sizes <= max_num_seq."

        return self.build(0, m)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: MLUCommonAttentionMetadata,
              fast_build: bool = False,
              input_batch: "InputBatch" = None) -> M:
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.device
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
                                   query_seq_lens_cpu)
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: support normal and mtp input split
        '''
        if input_batch is None:
            num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
                split_decodes_and_prefills(common_attn_metadata,
                                           self.decoder_query_len)
        else:
            num_decodes, num_prefills = input_batch.split_decodes_and_prefills()
            num_decode_tokens = common_attn_metadata.query_start_loc_cpu[num_decodes].item()
            num_prefill_tokens = num_tokens - num_decode_tokens
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes  # prefill_start

            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: avoid buffer missing when prefill_only + mlugraph
            '''
            if num_decodes > 0:
                prefill_query_start_loc = query_start_loc[
                    reqs_start:] - query_start_loc[reqs_start]
            else:
                prefill_query_start_loc = query_start_loc
            '''
            ==================
            End of MLU Hijack
            ==================
            '''

            chunked_context_metadata = None
            if ((self.chunked_prefill_enabled or
                 (mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED and
                  self.enable_caching and
                  common_attn_metadata.is_chunked)
                 ) and num_prefills > 0 and max_context_len_cpu > 0):
                # NOTE: it is recommended you read the `Chunked Prefill` section
                # in the comment at the top of the file before trying to
                # understand the following code

                # currently we allocate an equal amount of workspace for each
                # prefill in the batch, we could probably use a more advanced
                # algorithm here and allocate more workspace to prefills with
                # longer context lengths
                if self.is_deepseek_v32:
                    max_context_chunk = self.max_model_len
                else:
                    max_context_chunk = (self.chunked_prefill_workspace_size //
                                         num_prefills_with_context_cpu)

                if self.aot_schedule:
                    # align max_context_chunk to page_size by rounding down,
                    # currently the `gather_cache` kernel cannot handle
                    # `context_chunk_starts` that are not aligned to page_size
                    max_context_chunk = round_down(max_context_chunk,
                                                   self.page_size)

                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)

                # if `max_context_chunk = 256`, `num_chunks = 3`, and
                # `num_prefills_with_context = 4`, create a tensor that looks
                # like
                # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
                # Note(simon): this is done on the CPU because of downstream's
                # use of `to_list`.
                chunk_starts = \
                    torch.arange(num_chunks, dtype=torch.int32) \
                    .unsqueeze(1).expand(-1, num_prefills) \
                    * max_context_chunk
                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                       chunk_starts + max_context_chunk)
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

                cu_seq_lens_cpu = torch.zeros(num_chunks,
                                              num_prefills + 1,
                                              dtype=torch.int32,
                                              pin_memory=True)
                torch.cumsum(chunk_seq_lens,
                             dim=1,
                             out=cu_seq_lens_cpu[:, 1:],
                             dtype=torch.int32)

                chunked_context_metadata_cls = \
                    FlashMLAPrefillMetadata.ChunkedContextMetadata

                chunked_context_metadata = \
                    chunked_context_metadata_cls(
                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                        starts=chunk_starts.to(device, non_blocking=True),
                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        seq_lens=chunk_seq_lens,
                        workspace=getattr(self, "chunked_prefill_workspace", None),
                    )

                if not self.is_deepseek_v32:
                    assert max(chunked_context_metadata.max_seq_lens) <= \
                        self.chunked_prefill_workspace_size

            prefill_metadata = self.prefill_metadata_cls(
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                max_query_len=max_query_len,
                chunked_context=chunked_context_metadata,
                # for mlu
                num_prefills=num_prefills,
                max_seq_len=common_attn_metadata.seq_lens_cpu[reqs_start:].max().item(),
            )

        decode_metadata = None
        if num_decodes > 0:
            decode_metadata = self._build_decode(
                block_table_tensor=block_table_tensor[:num_decodes, ...],
                seq_lens=seq_lens[:num_decodes],
                query_start_loc=query_start_loc[:num_decodes + 1],
                max_query_len=query_seq_lens_cpu[:num_decodes].max().item(),
                max_seq_len=common_attn_metadata.seq_lens_cpu[:num_decodes].max().item(),
            )

        attn_metadata = self.metadata_cls(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=num_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            # MLACommonMetadata Chunk prefill specific
            num_decodes=num_decodes,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

        return attn_metadata

    def can_run_in_cudagraph(
            self, common_attn_metadata: MLUCommonAttentionMetadata) -> bool:
        return common_attn_metadata.max_query_len == self.decoder_query_len

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        return False


class FlashMLAImpl(MLUFlashAttentionImpl):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        kwargs: Optional[dict[str, Any]] = {},
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        out_lse = None

        # use default common metadata if kwargs does not carry common_metadata
        common_metadata: MLUCommonAttentionMetadata = kwargs.get("common_metadata", None)
        if common_metadata is None:
            common_metadata = get_common_metadata()

        only_prefill = kwargs.get("only_prefill", False)
        only_decode = kwargs.get("only_decode", False)
        attn_bias = kwargs.get("attn_bias", None)

        assert only_prefill != only_decode, \
            "exactly one of only_prefill and only_decode must be True."

        if only_prefill:
            cu_seqlens_q = attn_metadata.prefill.query_start_loc
            cu_seqlens_kv = common_metadata.query_start_loc
            seqused_k = common_metadata.seq_lens[attn_metadata.num_decodes:]
            max_seqlen_q = attn_metadata.prefill.max_query_len
            max_seqlen_k = attn_metadata.prefill.max_seq_len
            block_table = attn_metadata.prefill.block_table
            num_actual_tokens = attn_metadata.num_prefill_tokens
        else:
            cu_seqlens_q = None  # not used
            cu_seqlens_kv = None  # not used
            seqused_k = common_metadata.seq_lens[:attn_metadata.num_decodes]
            max_seqlen_q = None  # not used
            max_seqlen_k = common_metadata.max_seq_len
            block_table = attn_metadata.decode.block_table
            num_actual_tokens = attn_metadata.num_decode_tokens

        skip_process_cache = ((self.use_mla
                               and (common_metadata.is_prefill_only
                                    or self.use_fused_mla_qkv
                                    or only_prefill))
                              or self.kv_sharing_target_layer_name is not None)

        kv_cache_, kv_cache_scale_, kv_cache_index_ = kv_cache
        key_cache = kv_cache_[0]
        value_cache = None if self.use_mla else kv_cache_[1]
        key_cache_scale, value_cache_scale = None, None
        if kv_cache_scale_.numel() > 0:
            key_cache_scale = kv_cache_scale_[0]
            value_cache_scale = None if self.use_mla else kv_cache_scale_[1]
        if not skip_process_cache:
            if is_quantized_kv_cache(self.kv_cache_dtype):
                mlu_ops.quant_to_paged_cache(
                    k=key[:num_actual_tokens],
                    v=(None if self.use_mla else value[:num_actual_tokens]),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    k_cache_quant_scale=key_cache_scale,
                    v_cache_quant_scale=value_cache_scale,
                    slot_mapping=attn_metadata.slot_mapping.flatten(),
                )
            else:
                mlu_ops.reshape_paged_cache(
                    k=key[:num_actual_tokens],
                    v=(None if self.use_mla else value[:num_actual_tokens]),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    slot_mapping=attn_metadata.slot_mapping.flatten()
                )

        alibi_slopes = None if self.alibi_slopes is None else \
            self.alibi_slopes.repeat(seqused_k.shape[0], 1)

        if kwargs.get("model_type", "") == "deepseek_v32":
            from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context
            sp_context = get_sp_forward_context()
            if sp_context is not None and sp_context.is_v32:
                num_actual_tokens = sp_context.sp_attn_metadata.num_prefill_tokens
            decode_query = query[:num_actual_tokens].view(-1, self.num_heads, self.head_size)
            head_size_v = value.shape[-1] if self.use_mla else self.head_size
            decode_output = output[:num_actual_tokens].view(-1, self.num_heads, head_size_v)
            decode_query = query.unsqueeze(1)  # treat tokens as the batch dim
            decode_output = decode_output.unsqueeze(1)
            q_quant_scale = kwargs.get("q_quant_scale", None)
            if q_quant_scale is not None:
                q_quant_scale = q_quant_scale[:num_actual_tokens].view(-1, self.num_heads)
                q_quant_scale = q_quant_scale.unsqueeze(1)
            mlu_ops.single_query_cached_kv_attn(
                q=decode_query,
                k_cache=key_cache,
                v_cache=value_cache,
                out=decode_output,
                block_tables=kwargs.get("new_block_tables", None),
                context_lens=kwargs.get("new_context_lens", None),
                k_cache_quant_scale=key_cache_scale,
                v_cache_quant_scale=value_cache_scale,
                alibi_slopes=alibi_slopes,
                max_contxt_len=kwargs.get("index_topk", None),
                windows_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]),
                windows_size_right=(-1 if self.sliding_window is None else self.sliding_window[0]),
                softmax_scale=self.scale,
                head_size_v=(-1 if not self.use_mla else head_size_v),
                compute_dtype=compute_dtype,
                q_quant_scale=q_quant_scale,
                decoder_attn_dtype=self.decoder_attn_dtype,
                mask=attn_bias,
            )
            return output

        if common_metadata.is_prefill_only or only_prefill:
            # prefill only
            prefill_causal = kwargs.get("prefill_causal", True)
            cu_seqlens_q = kwargs.get("cu_seq_lens_q", cu_seqlens_q)
            cu_seqlens_kv = kwargs.get("cu_seq_lens_kv", cu_seqlens_kv)
            max_seqlen_q = kwargs.get("max_seq_len_q", max_seqlen_q)
            max_seqlen_k = kwargs.get("max_seq_len_kv", max_seqlen_k)
            return_lse = kwargs.get("return_lse", False)
            num_prefill_query_tokens = common_metadata.num_prefill_query_tokens
            num_prefill_kv_tokens = common_metadata.num_prefill_kv_tokens
            use_f32 = attn_bias is not None and attn_bias.dtype == torch.float32
            if use_f32:
                f32_output = torch.empty_like(output, dtype=torch.float32)
            attn_output_list = mlu_ops.flash_attention(
                q=query[:num_prefill_query_tokens].to(torch.float32) if use_f32 else query[:num_prefill_query_tokens],
                k=key[:num_prefill_kv_tokens].to(torch.float32) if use_f32 else key[:num_prefill_kv_tokens],
                v=value[:num_prefill_kv_tokens].to(torch.float32) if use_f32 else value[:num_prefill_kv_tokens],
                out=f32_output[:num_prefill_query_tokens] if use_f32 else output[:num_prefill_query_tokens],
                cu_seq_lens_q=cu_seqlens_q,
                cu_seq_lens_kv=cu_seqlens_kv,
                alibi_slope=alibi_slopes,
                attn_bias=attn_bias,
                max_seq_len_q=max_seqlen_q,
                max_seq_len_kv=max_seqlen_k,
                softmax_scale=self.scale,
                is_causal=prefill_causal,
                window_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]),
                window_size_right=(-1 if self.sliding_window is None else self.sliding_window[1]),
                compute_dtype=self.prefill_compute_dtype,
                return_lse=return_lse,
                q_quant_dtype=self.prefill_q_dtype,
                k_quant_dtype=self.prefill_k_dtype,
                v_quant_dtype=self.prefill_v_dtype
            )
            if use_f32:
                output[:num_prefill_query_tokens].copy_(f32_output[:num_prefill_query_tokens])

            if return_lse:
                out_lse = attn_output_list[1]
        else:
            batch_size = block_table.shape[0]
            # decode only
            decode_query = query[:num_actual_tokens].view(batch_size, -1, self.num_heads, self.head_size)
            head_size_v = value.shape[-1] if self.use_mla else self.head_size
            decode_output = output[:num_actual_tokens].view(batch_size, -1, self.num_heads, head_size_v)
            q_quant_scale = kwargs.get("q_quant_scale", None)
            if q_quant_scale is not None:
                q_quant_scale = q_quant_scale[:num_actual_tokens].view(batch_size, -1, self.num_heads)
            mlu_ops.single_query_cached_kv_attn(
                q=decode_query,
                k_cache=key_cache,
                v_cache=value_cache,
                out=decode_output,
                block_tables=block_table,
                context_lens=seqused_k,
                k_cache_quant_scale=key_cache_scale,
                v_cache_quant_scale=value_cache_scale,
                alibi_slopes=alibi_slopes,
                max_contxt_len=max_seqlen_k,
                windows_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]),
                windows_size_right=(-1 if self.sliding_window is None else self.sliding_window[0]),
                softmax_scale=self.scale,
                head_size_v=(-1 if not self.use_mla else head_size_v),
                compute_dtype=attn_metadata.decode.compute_dtype,
                q_quant_scale=q_quant_scale,
                decoder_attn_dtype=self.decoder_attn_dtype,
                mask=attn_bias,
            )

        return output if out_lse is None else (output, out_lse)

295  vllm_mlu/v1/attention/backends/utils.py  Normal file
@@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import os
import numpy as np
import pandas as pd
import torch
from typing import TYPE_CHECKING, Union

from dataclasses import dataclass
from enum import Enum

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

from vllm.forward_context import get_forward_context
from vllm.v1.attention.backends.utils import CommonAttentionMetadata


COMMON_METADATA_STR: str = "common_metadata"


class MLUInferMode(Enum):
    CHUNKED = 1
    PREFILL_ONLY = 2
    DECODE_ONLY = 3

    @classmethod
    def build(
        cls,
        max_query_len,
        max_computed_tokens,
        uniform_decode_query_len: int = 1,
    ) -> Enum:
        if max_query_len <= uniform_decode_query_len:
            return MLUInferMode.DECODE_ONLY
        elif max_computed_tokens == 0:
            return MLUInferMode.PREFILL_ONLY
        else:
            return MLUInferMode.CHUNKED

    @property
    def is_prefill_only(self):
        return self == MLUInferMode.PREFILL_ONLY

    @property
    def is_decode_only(self):
        return self == MLUInferMode.DECODE_ONLY

    @property
    def is_chunked(self):
        return self == MLUInferMode.CHUNKED

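Not part of the diff: how the classifier above resolves for a few example inputs (values chosen only for illustration):

```python
# Decode-only: every request contributes at most uniform_decode_query_len tokens.
MLUInferMode.build(max_query_len=1, max_computed_tokens=123)
# -> MLUInferMode.DECODE_ONLY

# Prefill-only: longer queries, but nothing has been computed yet.
MLUInferMode.build(max_query_len=512, max_computed_tokens=0)
# -> MLUInferMode.PREFILL_ONLY

# Chunked: long queries with some context already computed.
MLUInferMode.build(max_query_len=512, max_computed_tokens=256)
# -> MLUInferMode.CHUNKED

# With speculative decoding, a larger uniform decode length still counts as decode.
MLUInferMode.build(max_query_len=3, max_computed_tokens=40,
                   uniform_decode_query_len=3).is_decode_only   # True
```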
@dataclass
|
||||
class MLUCommonAttentionMetadata(CommonAttentionMetadata):
|
||||
"""
|
||||
Attention metadata attributes that can be shared by layers in different KV
|
||||
cache groups and thus having different block table.
|
||||
"""
|
||||
seq_start_loc: torch.Tensor | None = None
|
||||
seq_start_loc_cpu: torch.Tensor | None = None
|
||||
"""(batch_size + 1,), the start location of each request in the input key/value sequence."""
|
||||
num_input_tokens: int = 0
|
||||
"""Number of query tokens with padding."""
|
||||
num_prefill_query_tokens: int = 0
|
||||
"""Number of query tokens in prefill phase."""
|
||||
num_prefill_kv_tokens: int = 0
|
||||
"""Number of key/value tokens in prefill phase."""
|
||||
infer_mode: MLUInferMode | None = None
|
||||
"""Inference mode for flash attention."""
|
||||
|
||||
@property
|
||||
def is_prefill_only(self):
|
||||
return self.infer_mode == MLUInferMode.PREFILL_ONLY
|
||||
|
||||
@property
|
||||
def is_decode_only(self):
|
||||
return self.infer_mode == MLUInferMode.DECODE_ONLY
|
||||
|
||||
@property
|
||||
def is_chunked(self):
|
||||
return self.infer_mode == MLUInferMode.CHUNKED
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
query_start_loc, query_start_loc_cpu,
|
||||
seq_lens, seq_lens_cpu,
|
||||
num_computed_tokens_cpu,
|
||||
num_reqs, num_actual_tokens, max_query_len,
|
||||
block_table_tensor, slot_mapping,
|
||||
seq_start_loc, is_start_loc_match,
|
||||
num_input_tokens: int = 0,
|
||||
num_speculative_tokens: int = 0,
|
||||
has_prefill_reqs: bool = False
|
||||
):
|
||||
"""Build attention metadata for MLU inference.
|
||||
|
||||
Args:
|
||||
has_prefill_reqs: Whether there are pending prefill requests with chunked.
|
||||
"""
|
||||
infer_mode = None
|
||||
if is_start_loc_match:
|
||||
infer_mode = MLUInferMode.PREFILL_ONLY
|
||||
elif max_query_len <= (1 + num_speculative_tokens) and (not has_prefill_reqs):
|
||||
infer_mode = MLUInferMode.DECODE_ONLY
|
||||
else:
|
||||
infer_mode = MLUInferMode.CHUNKED
|
||||
num_input_tokens = (
|
||||
num_actual_tokens if num_input_tokens == 0
|
||||
else num_input_tokens
|
||||
)
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
return cls(query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_start_loc=seq_start_loc,
|
||||
seq_start_loc_cpu=seq_start_loc.to("cpu", non_blocking=True),
|
||||
num_input_tokens=num_input_tokens,
|
||||
infer_mode=infer_mode,
|
||||
num_prefill_query_tokens=num_actual_tokens,
|
||||
num_prefill_kv_tokens=num_actual_tokens)
|
||||
|
||||
def save(self, infer_phase: str):
|
||||
csv_path = os.getenv("VLLM_STEP_INPUT_CSV_PATH", None)
|
||||
if not csv_path:
|
||||
return
|
||||
|
||||
header = [
|
||||
"infer_phase", "infer_mode", "num_reqs", "num_actual_tokens",
|
||||
"max_query_len", "max_seq_len", "query_start_loc", "seq_lens"
|
||||
]
|
||||
data = [
|
||||
infer_phase, self.infer_mode, self.num_reqs,
|
||||
self.num_actual_tokens, self.max_query_len, self.max_seq_len,
|
||||
str(self.query_start_loc_cpu.tolist()),
|
||||
str(self.seq_lens_cpu.tolist())
|
||||
]
|
||||
data_dict = dict(zip(header, data))
|
||||
df_csv = pd.DataFrame(data_dict, index=[0])
|
||||
|
||||
if infer_phase == "RealInfer":
|
||||
print(df_csv.to_string())
|
||||
|
||||
try:
|
||||
if dir_path := os.path.dirname(csv_path):
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
append = False
|
||||
if os.path.isfile(csv_path):
|
||||
try:
|
||||
df_old = pd.read_csv(csv_path)
|
||||
append = (df_old.columns.tolist() == header)
|
||||
except Exception as e:
|
||||
print(f"Existing {csv_path} could not be read and will be overwritten: {e}")
|
||||
if append:
|
||||
df_csv.to_csv(csv_path, mode='a', header=False, index=False)
|
||||
else:
|
||||
df_csv.to_csv(csv_path, index=False)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to dump step inputs to VLLM_STEP_INPUT_CSV_PATH={csv_path}: {e}") from e
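
# Usage sketch (assumed call site; the environment variable is the one read above):
#
#   export VLLM_STEP_INPUT_CSV_PATH=/tmp/step_inputs.csv
#
#   common_metadata.save("RealInfer")
#
# Rows are appended when the existing CSV header matches; otherwise the file
# is rewritten with the current header.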
|
||||
|
||||
|
||||
def get_common_metadata_from_attn_metadata(
|
||||
attn_metadata) -> Union[MLUCommonAttentionMetadata, None]:
|
||||
"""
|
||||
Get MLUCommonAttentionMetadata for MLU-V1 inference.
|
||||
Use outside of set_forward_context().
|
||||
"""
|
||||
if attn_metadata is None:
|
||||
return
|
||||
|
||||
assert (isinstance(attn_metadata, dict)
|
||||
and COMMON_METADATA_STR in attn_metadata), \
|
||||
f"MLU-V1 only supports type(attn_metadata)=dict, and " + \
|
||||
f"{COMMON_METADATA_STR} in attn_metadata. Now, type(attn_metadata)=" + \
|
||||
f"{type(attn_metadata)}, or {COMMON_METADATA_STR} not in attn_metadata."
|
||||
return attn_metadata[COMMON_METADATA_STR]
|
||||
|
||||
|
||||
def get_common_metadata() -> Union[MLUCommonAttentionMetadata, None]:
|
||||
"""
|
||||
Get MLUCommonAttentionMetadata for MLU-V1 inference.
|
||||
Use inside of set_forward_context().
|
||||
"""
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
return get_common_metadata_from_attn_metadata(attn_metadata)
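
# Usage sketch (assumed call site): inside a forward pass wrapped by
# set_forward_context(), layer code can pick a kernel path from the shared
# metadata:
#
#   common_md = get_common_metadata()
#   if common_md is not None and common_md.is_decode_only:
#       ...  # decode-optimized attention path
#   elif common_md is not None and common_md.is_prefill_only:
#       ...  # prefill path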
|
||||
|
||||
|
||||
def unpad_common_attn_metadata(
|
||||
common_metadata: MLUCommonAttentionMetadata,
|
||||
num_reqs: int,
|
||||
num_scheduled_tokens: int,
|
||||
):
|
||||
"""
|
||||
Unpad MLUCommonAttentionMetadata by given num_reqs and num_scheduled_tokens.
|
||||
"""
|
||||
common_metadata.num_reqs = num_reqs
|
||||
common_metadata.num_input_tokens = num_scheduled_tokens
|
||||
common_metadata.query_start_loc = common_metadata.query_start_loc[:num_reqs + 1]
|
||||
common_metadata.query_start_loc_cpu = common_metadata.query_start_loc_cpu[:num_reqs + 1]
|
||||
common_metadata.seq_start_loc = common_metadata.seq_start_loc[:num_reqs + 1]
|
||||
common_metadata.seq_lens = common_metadata.seq_lens[:num_reqs]
|
||||
common_metadata.seq_lens_cpu = common_metadata.seq_lens_cpu[:num_reqs]
|
||||
common_metadata.block_table_tensor = common_metadata.block_table_tensor[:num_reqs]
|
||||
|
||||
def reorder_batch_to_split_decodes_and_prefills(
|
||||
input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput",
|
||||
decode_threshold: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Reorders the batch to split into prefill and decode requests; places all
|
||||
requests with <= decode_threshold tokens at the front of the batch.
|
||||
|
||||
Returns:
|
||||
True if the batch was modified, False otherwise.
|
||||
"""
|
||||
# We now want to reorder the batch into decode → extend → prefill order
|
||||
# where:
|
||||
# decode: request with num_scheduled_tokens <= decode_threshold
|
||||
# extend: non-decode request with existing context
|
||||
# prefill: non-decode request with no existing context
|
||||
# NOTE for now we loosely use "decode" to mean requests where attention is
|
||||
# likely memory-bound and "prefill" to mean requests where attention is
|
||||
# likely compute-bound.
|
||||
num_reqs = len(input_batch.req_ids)
|
||||
num_scheduled_tokens = [
|
||||
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
|
||||
]
|
||||
num_scheduled_tokens_np = np.array(num_scheduled_tokens)
|
||||
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: enhance the decode-mode condition to also require that all prompt tokens are computed.
|
||||
'''
|
||||
# is_decode = num_scheduled_tokens_np <= decode_threshold
|
||||
is_decode = (
|
||||
(num_scheduled_tokens_np <= decode_threshold)
|
||||
& (num_computed_tokens_np >= input_batch.num_prompt_tokens[:num_reqs])
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
is_extend = (~is_decode) & (num_computed_tokens_np > 0)
|
||||
is_prefill = (~is_decode) & (num_computed_tokens_np == 0)
|
||||
|
||||
# Desired order: decode → extend → prefill
|
||||
req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default
|
||||
req_regions[is_extend] = 1
|
||||
req_regions[is_prefill] = 2
|
||||
|
||||
num_decodes = int(is_decode.sum())
|
||||
num_extends = int(is_extend.sum())
|
||||
|
||||
target_regions = np.zeros(num_reqs, dtype=np.int32)
|
||||
target_regions[num_decodes : num_decodes + num_extends] = 1
|
||||
target_regions[num_decodes + num_extends :] = 2
|
||||
|
||||
needs_swap = req_regions != target_regions
|
||||
|
||||
if not needs_swap.any():
|
||||
return False
|
||||
|
||||
# Extract indices that need swapping and sort by target region
|
||||
orig_indices = np.where(needs_swap)[0]
|
||||
sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
|
||||
src_indices = orig_indices[sorted_order]
|
||||
|
||||
src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)}
|
||||
|
||||
for src in src_dest_map:
|
||||
dst = src_dest_map[src]
|
||||
while src != dst:
|
||||
input_batch.swap_states(src, dst)
|
||||
# Mark dst as done by updating its destination to itself
|
||||
next_dst = src_dest_map.get(dst, dst)
|
||||
src_dest_map[dst] = dst
|
||||
dst = next_dst
|
||||
|
||||
return True
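
# Worked example (illustrative, assumed numbers): with decode_threshold=1,
#   num_scheduled_tokens = [8, 1, 4]
#   num_computed_tokens  = [16, 32, 0]
#   num_prompt_tokens    = [40, 32, 20]
# req0 is an extend (8 > 1 and context exists), req1 is a decode (1 token and
# the whole prompt is computed), req2 is a prefill (no context). The desired
# decode -> extend -> prefill order is [req1, req0, req2], so a single
# input_batch.swap_states(1, 0) reorders the batch and the function returns True.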
|
||||
3
vllm_mlu/v1/core/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
146
vllm_mlu/v1/core/kv_cache_manager.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import itertools
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, overload
|
||||
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
|
||||
from vllm.v1.core.kv_cache_utils import KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
class KVCacheManager_MluHijack(KVCacheManager):
|
||||
|
||||
def allocate_slots(
|
||||
self,
|
||||
request: Request,
|
||||
num_new_tokens: int,
|
||||
num_new_computed_tokens: int = 0,
|
||||
new_computed_blocks: KVCacheBlocks | None = None,
|
||||
num_lookahead_tokens: int = 0,
|
||||
delay_cache_blocks: bool = False,
|
||||
num_encoder_tokens: int = 0,
|
||||
fixed_window_tokens: int = 0,
|
||||
) -> KVCacheBlocks | None:
|
||||
"""Add slots for a request with new tokens to append.
|
||||
|
||||
Args:
|
||||
request: The request to allocate slots.
|
||||
num_new_tokens: The number of tokens to allocate, including external
|
||||
tokens. Note that this does not include tokens that have
|
||||
already been computed locally (i.e. new_computed_blocks).
|
||||
num_new_computed_tokens: The number of new computed tokens just
|
||||
hitting the prefix caching, excluding external tokens.
|
||||
new_computed_blocks: The cached blocks for the above new computed
|
||||
tokens.
|
||||
num_lookahead_tokens: The number of speculative tokens to allocate.
|
||||
This is used by spec decode proposers with kv-cache such
|
||||
as eagle.
|
||||
delay_cache_blocks: Whether to skip caching the blocks. This is
|
||||
used by P/D when allocating blocks used in a KV transfer
|
||||
which will complete in a future step.
|
||||
|
||||
Blocks layout:
|
||||
```
|
||||
-----------------------------------------------------------------------
|
||||
| < computed > | < new computed > | < new > | < pre-allocated > |
|
||||
-----------------------------------------------------------------------
|
||||
| < required > |
|
||||
--------------------------------------------------
|
||||
| < full > |
|
||||
------------------------------------------------
|
||||
| <new full> |
|
||||
--------------
|
||||
```
|
||||
The following *_blocks are illustrated in this layout.
|
||||
|
||||
Returns:
|
||||
A list of new allocated blocks.
|
||||
"""
|
||||
if num_new_tokens == 0:
|
||||
raise ValueError("num_new_tokens must be greater than 0")
|
||||
|
||||
if new_computed_blocks is not None:
|
||||
new_computed_block_list = new_computed_blocks.blocks
|
||||
else:
|
||||
new_computed_block_list = self.empty_kv_cache_blocks.blocks
|
||||
|
||||
# Free the blocks that are skipped during the attention computation
|
||||
# (e.g., tokens outside the sliding window).
|
||||
# We can do this even if we cannot schedule this request due to
|
||||
# insufficient free blocks.
|
||||
# Should call this function before allocating new blocks to reduce
|
||||
# the number of evicted blocks.
|
||||
self.coordinator.remove_skipped_blocks(
|
||||
request.request_id, request.num_computed_tokens
|
||||
)
|
||||
|
||||
# The number of computed tokens is the number of computed tokens plus
|
||||
# the new prefix caching hits
|
||||
num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens
|
||||
num_tokens_need_slot = min(
|
||||
num_computed_tokens + num_new_tokens + num_lookahead_tokens + fixed_window_tokens,
|
||||
self.max_model_len,
|
||||
)
|
||||
|
||||
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
|
||||
request_id=request.request_id,
|
||||
num_tokens=num_tokens_need_slot,
|
||||
new_computed_blocks=new_computed_block_list,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
)
|
||||
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
# Cannot allocate new blocks
|
||||
return None
|
||||
|
||||
# Touch the computed blocks to make sure they won't be evicted.
|
||||
if self.enable_caching:
|
||||
self.block_pool.touch(new_computed_block_list)
|
||||
else:
|
||||
assert not any(new_computed_block_list), (
|
||||
"Computed blocks should be empty when prefix caching is disabled"
|
||||
)
|
||||
|
||||
if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
|
||||
# Append the new computed blocks to the request blocks until now to
|
||||
# avoid the case where the new blocks cannot be allocated.
|
||||
self.coordinator.save_new_computed_blocks(
|
||||
request.request_id, new_computed_block_list
|
||||
)
|
||||
|
||||
new_blocks = self.coordinator.allocate_new_blocks(
|
||||
request.request_id, num_tokens_need_slot, num_encoder_tokens
|
||||
)
|
||||
|
||||
# P/D: delay caching blocks if we have to recv from
|
||||
# remote. Update state for locally cached blocks.
|
||||
if not self.enable_caching or delay_cache_blocks:
|
||||
return self.create_kv_cache_blocks(new_blocks)
|
||||
|
||||
# NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
|
||||
# num_new_tokens, but must exclude "non-committable" tokens (e.g.,
|
||||
# draft tokens that could be rejected). Therefore, we cap the number
|
||||
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
|
||||
num_tokens_to_cache = min(
|
||||
num_computed_tokens + num_new_tokens, request.num_tokens
|
||||
)
|
||||
self.coordinator.cache_blocks(request, num_tokens_to_cache)
|
||||
|
||||
return self.create_kv_cache_blocks(new_blocks)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(KVCacheManager,
|
||||
KVCacheManager.allocate_slots,
|
||||
KVCacheManager_MluHijack.allocate_slots)
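
# Usage sketch (assumed scheduler-side call; argument values illustrative):
# the hijacked allocate_slots accepts a `fixed_window_tokens` argument so
# extra slots can be reserved beyond the scheduled tokens:
#
#   new_blocks = kv_cache_manager.allocate_slots(
#       request,
#       num_new_tokens=64,
#       num_lookahead_tokens=2,
#       fixed_window_tokens=128,
#   )
#   if new_blocks is None:
#       ...  # not enough free blocks; caller must wait or preempt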
|
||||
123
vllm_mlu/v1/core/kv_cache_utils.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
KVCacheSpec,
|
||||
KVCacheTensor,
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
from vllm.v1.core import kv_cache_utils
|
||||
from vllm.v1.core.kv_cache_utils import (may_override_num_blocks,
|
||||
get_uniform_page_size,
|
||||
get_num_blocks)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
def vllm__v1__core__kv_cache_utils__get_kv_cache_config_from_groups(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_groups: list[KVCacheGroupSpec],
|
||||
kv_cache_specs: dict[str, KVCacheSpec],
|
||||
available_memory: int,
|
||||
) -> KVCacheConfig:
|
||||
"""
|
||||
Generate the KV cache configuration from the KV cache groups and spec
|
||||
of each layer.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_groups: The KV cache groups
|
||||
kv_cache_specs: The KV cache spec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
"""
|
||||
if len(kv_cache_groups) == 0:
|
||||
# Attention free models do not have KV cache.
|
||||
# Return num_blocks=1 as BlockPool always needs a null_block.
|
||||
return KVCacheConfig(
|
||||
num_blocks=1,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=kv_cache_groups,
|
||||
)
|
||||
|
||||
# Determine how model runners should initialize the KV cache tensors.
|
||||
if len(kv_cache_groups) == 1 and isinstance(
|
||||
kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs
|
||||
):
|
||||
# Special case: all layers have the same type of KV cache but with
|
||||
# different hidden size. Allocate different amount of memory for each
|
||||
# layer based on its hidden size.
|
||||
num_blocks = (
|
||||
available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes
|
||||
)
|
||||
num_blocks = may_override_num_blocks(vllm_config, num_blocks)
|
||||
per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs
|
||||
kv_cache_tensors = [
|
||||
KVCacheTensor(
|
||||
size=per_layer_specs[layer_name].page_size_bytes * num_blocks,
|
||||
shared_by=[layer_name],
|
||||
)
|
||||
for layer_name in kv_cache_groups[0].layer_names
|
||||
]
|
||||
else:
|
||||
# General case:
|
||||
# We will have group_size memory pools, each is shared by one layer from
|
||||
# each group. As layers of different groups have different block table,
|
||||
# they will use different parts of the shared Tensor.
|
||||
# The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
|
||||
# (sw.1, padding) will be: (group_size = 2)
|
||||
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
|
||||
# full.1, sw.2: share another Tensor with size=available_memory//2
|
||||
group_size = max(len(group.layer_names) for group in kv_cache_groups)
|
||||
|
||||
page_size = get_uniform_page_size(kv_cache_specs)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support qwen3-next
|
||||
'''
|
||||
if vllm_config.mlu_config.enable_mamba_split_page_size:
|
||||
# Note(wulingchao): 预留出linear attention的内存不参与系统调度
|
||||
# 当前的 page size是小page,需要扩展到完整的linear attention的page
|
||||
mamba_page_size = (page_size
|
||||
* vllm_config.mlu_config.mamba_to_attn_block_ratio
|
||||
* vllm_config.mlu_config.mamba_support_max_batch_size
|
||||
* group_size * 3)
|
||||
logger.warning(f"Total available KV cache memory: {available_memory} bytes; reserved for mamba pages: {mamba_page_size} bytes")
|
||||
available_memory = available_memory - mamba_page_size
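# Worked example (assumed values): with page_size = 2 MiB,
# mamba_to_attn_block_ratio = 8, mamba_support_max_batch_size = 64 and
# group_size = 1, the reservation is 2 MiB * 8 * 64 * 1 * 3 = 3 GiB, which is
# subtracted from available_memory before num_blocks is computed, so the
# linear-attention state never competes with regular KV cache blocks.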
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
assert group_size > 0, "group_size must be greater than 0"
|
||||
num_blocks = get_num_blocks(
|
||||
vllm_config, group_size, available_memory, page_size
|
||||
)
|
||||
kv_cache_tensors = []
|
||||
for i in range(group_size):
|
||||
shared_by = []
|
||||
for j in range(len(kv_cache_groups)):
|
||||
if i < len(kv_cache_groups[j].layer_names):
|
||||
shared_by.append(kv_cache_groups[j].layer_names[i])
|
||||
kv_cache_tensors.append(
|
||||
KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
|
||||
)
|
||||
|
||||
return KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
kv_cache_tensors=kv_cache_tensors,
|
||||
kv_cache_groups=kv_cache_groups,
|
||||
)
|
||||
|
||||
MluHijackObject.apply_hijack(kv_cache_utils,
|
||||
kv_cache_utils.get_kv_cache_config_from_groups,
|
||||
vllm__v1__core__kv_cache_utils__get_kv_cache_config_from_groups)
|
||||
|
||||
|
||||
3
vllm_mlu/v1/core/sched/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
136
vllm_mlu/v1/core/sched/async_scheduler.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
from vllm_mlu.v1.core.sched.scheduler import MLUUnchunkScheduler, SchedulerWithProfiler
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AsyncScheduler(SchedulerWithProfiler):
|
||||
def _update_after_schedule(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> None:
|
||||
super()._update_after_schedule(scheduler_output)
|
||||
pending_structured_output_tokens = False
|
||||
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||
for req_id in scheduler_output.num_scheduled_tokens:
|
||||
request = self.requests[req_id]
|
||||
pending_structured_output_tokens |= (
|
||||
request.use_structured_output and request.num_output_placeholders > 0
|
||||
)
|
||||
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
|
||||
if (
|
||||
request.num_computed_tokens
|
||||
== request.num_tokens
|
||||
+ request.num_output_placeholders
|
||||
+ cur_num_spec_tokens
|
||||
):
|
||||
# The request will generate a new token plus num_spec_tokens
|
||||
# in this scheduling step.
|
||||
request.num_output_placeholders += 1 + cur_num_spec_tokens
|
||||
# Add placeholders for the new tokens in spec_token_ids.
|
||||
# We will update the actual spec token ids in the worker process.
|
||||
request.spec_token_ids = [-1] * self.num_spec_tokens
|
||||
|
||||
scheduler_output.pending_structured_output_tokens = (
|
||||
pending_structured_output_tokens
|
||||
)
|
||||
|
||||
def _update_request_with_output(
|
||||
self,
|
||||
request: Request,
|
||||
new_token_ids: list[int],
|
||||
) -> tuple[list[int], bool]:
|
||||
status_before_update = request.status
|
||||
new_token_ids, stopped = super()._update_request_with_output(
|
||||
request, new_token_ids
|
||||
)
|
||||
|
||||
# Update the number of output placeholders.
|
||||
request.num_output_placeholders -= len(new_token_ids)
|
||||
assert request.num_output_placeholders >= 0
|
||||
|
||||
# Cache the new tokens. Preempted requests should be skipped.
|
||||
if status_before_update == RequestStatus.RUNNING:
|
||||
self.kv_cache_manager.cache_blocks(
|
||||
request, request.num_computed_tokens - request.num_output_placeholders
|
||||
)
|
||||
return new_token_ids, stopped
|
||||
|
||||
class MLUUnchunkAsyncScheduler(MLUUnchunkScheduler):
|
||||
def _update_after_schedule(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> None:
|
||||
super()._update_after_schedule(scheduler_output)
|
||||
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||
for req_id in scheduler_output.num_scheduled_tokens:
|
||||
request = self.requests[req_id]
|
||||
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, []))
|
||||
if (
|
||||
request.num_computed_tokens
|
||||
== request.num_tokens
|
||||
+ request.num_output_placeholders
|
||||
+ cur_num_spec_tokens
|
||||
):
|
||||
# The request will generate a new token plus num_spec_tokens
|
||||
# in this scheduling step.
|
||||
request.num_output_placeholders += 1 + cur_num_spec_tokens
|
||||
# Add placeholders for the new tokens in spec_token_ids,
|
||||
# because the actual token ids are not known yet. Use -1 as a
|
||||
# placeholder and set the length of spec_token_ids to
|
||||
# self.num_spec_tokens; the actual spec token ids are filled in
|
||||
# by the worker process.
|
||||
request.spec_token_ids = [-1] * self.num_spec_tokens
|
||||
|
||||
def _update_request_with_output(
|
||||
self,
|
||||
request: Request,
|
||||
new_token_ids: list[int],
|
||||
) -> tuple[list[int], bool]:
|
||||
status_before_update = request.status
|
||||
new_token_ids, stopped = super()._update_request_with_output(
|
||||
request, new_token_ids)
|
||||
|
||||
# num_output_placeholders = 0 happens when a request is preempted.
|
||||
# A preempted request is added back to the waiting queue and
|
||||
# num_output_placeholders is reset to 0,
|
||||
# so there is no need to revert num_output_placeholders in that case.
|
||||
if request.num_output_placeholders > 0:
|
||||
# Update the number of output placeholders.
|
||||
request.num_output_placeholders -= len(new_token_ids)
|
||||
assert request.num_output_placeholders >= 0
|
||||
|
||||
# Cache the new tokens. Preempted requests should be skipped.
|
||||
if status_before_update == RequestStatus.RUNNING:
|
||||
self.kv_cache_manager.cache_blocks(
|
||||
request,
|
||||
request.num_computed_tokens - request.num_output_placeholders)
|
||||
return new_token_ids, stopped
|
||||
|
||||
|
||||
def _update_computed_tokens_after_speculation(
|
||||
self, request: Request, num_rejected: int
|
||||
):
|
||||
"""Update the computed tokens for each request, which is necessary
|
||||
for spec decoding. In sync scheduler, we need to revert
|
||||
num_computed_tokens by num_rejected tokens,
|
||||
but in async scheduler, we also need to revert num_output_placeholders
|
||||
by num_rejected tokens for spec decoding.
|
||||
"""
|
||||
# num_computed_tokens = 0 happens when a request is preempted.
|
||||
# A preempted request is added back to the waiting queue and
|
||||
# num_computed_tokens is reset to 0,
|
||||
# so there is no need to revert num_computed_tokens in that case.
|
||||
if request.num_computed_tokens > 0:
|
||||
# when spec decoding is enabled, num_output_placeholders
|
||||
# is increased by num_spec_tokens in _update_after_schedule.
|
||||
# update num_output_placeholders here to reflect the actual number
|
||||
# of accepted output tokens.
|
||||
request.num_output_placeholders -= num_rejected
|
||||
super()._update_computed_tokens_after_speculation(request, num_rejected)
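
# Worked accounting example (illustrative numbers): with num_spec_tokens = 3,
# _update_after_schedule adds 1 + 3 = 4 output placeholders for a request.
# If num_rejected = 2 draft tokens are later rejected, this method first
# removes those 2 placeholders and then lets the base class roll
# num_computed_tokens back by the same 2 tokens, so both counters end up
# reflecting only the accepted tokens.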
|
||||
111
vllm_mlu/v1/core/sched/output.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm._bc_linter import bc_linter_include
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.request import Request
|
||||
else:
|
||||
ECConnectorMetadata = object
|
||||
KVConnectorMetadata = object
|
||||
LoRARequest = object
|
||||
MultiModalFeatureSpec = object
|
||||
PoolingParams = object
|
||||
SamplingParams = object
|
||||
Request = object
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add new_token_ids to pass the first token generated
|
||||
by the prefiller to the decoder's model_runner.
|
||||
'''
|
||||
@bc_linter_include
|
||||
@dataclass
|
||||
class NewRequestData:
|
||||
req_id: str
|
||||
prompt_token_ids: list[int] | None
|
||||
mm_features: list[MultiModalFeatureSpec]
|
||||
sampling_params: SamplingParams | None
|
||||
pooling_params: PoolingParams | None
|
||||
block_ids: tuple[list[int], ...]
|
||||
num_computed_tokens: int
|
||||
lora_request: LoRARequest | None
|
||||
new_token_ids: list[list[int]]
|
||||
prompt_embeds: "torch.Tensor | None" = None
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
cls,
|
||||
request: Request,
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> "NewRequestData":
|
||||
return cls(
|
||||
req_id=request.request_id,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
mm_features=request.mm_features,
|
||||
sampling_params=request.sampling_params,
|
||||
pooling_params=request.pooling_params,
|
||||
block_ids=block_ids,
|
||||
num_computed_tokens=request.num_computed_tokens,
|
||||
lora_request=request.lora_request,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
new_token_ids=request._output_token_ids,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds is not None else None
|
||||
return (
|
||||
f"NewRequestData("
|
||||
f"req_id={self.req_id},"
|
||||
f"prompt_token_ids={self.prompt_token_ids},"
|
||||
f"mm_features={self.mm_features},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
f"block_ids={self.block_ids},"
|
||||
f"num_computed_tokens={self.num_computed_tokens},"
|
||||
f"lora_request={self.lora_request},"
|
||||
f"prompt_embeds_shape={prompt_embeds_shape},"
|
||||
f"new_token_ids={self.new_token_ids}"
|
||||
")"
|
||||
)
|
||||
|
||||
# Version of __repr__ with the prompt data obfuscated
|
||||
def anon_repr(self) -> str:
|
||||
prompt_token_ids_len = (
|
||||
len(self.prompt_token_ids) if self.prompt_token_ids is not None else None
|
||||
)
|
||||
prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds is not None else None
|
||||
return (
|
||||
f"NewRequestData("
|
||||
f"req_id={self.req_id},"
|
||||
f"prompt_token_ids_len={prompt_token_ids_len},"
|
||||
f"mm_features={self.mm_features},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
f"block_ids={self.block_ids},"
|
||||
f"num_computed_tokens={self.num_computed_tokens},"
|
||||
f"lora_request={self.lora_request},"
|
||||
f"prompt_embeds_shape={prompt_embeds_shape}"
|
||||
")"
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
1723
vllm_mlu/v1/core/sched/scheduler.py
Normal file
File diff suppressed because it is too large
Load Diff
21
vllm_mlu/v1/core/single_type_kv_cache_manager.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
FullAttentionManager,
|
||||
SlidingWindowManager,
|
||||
spec_manager_map,
|
||||
)
|
||||
|
||||
from vllm_mlu.v1.kv_cache_interface import (
|
||||
MLUFullAttentionSpec,
|
||||
MLUMLAAttentionSpec,
|
||||
MLUSlidingWindowSpec,
|
||||
)
|
||||
|
||||
|
||||
spec_manager_map.update({
|
||||
MLUFullAttentionSpec: FullAttentionManager,
|
||||
MLUSlidingWindowSpec: SlidingWindowManager,
|
||||
MLUMLAAttentionSpec: FullAttentionManager,
|
||||
})
|
||||
3
vllm_mlu/v1/engine/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
23
vllm_mlu/v1/engine/async_llm.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
class AsyncLLM_MluHijack(AsyncLLM):
|
||||
|
||||
async def start_scheduler_profile(self) -> None:
|
||||
await self.engine_core.start_scheduler_profile()
|
||||
|
||||
async def stop_scheduler_profile(self) -> None:
|
||||
await self.engine_core.stop_scheduler_profile()
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(AsyncLLM,
|
||||
"start_scheduler_profile",
|
||||
AsyncLLM_MluHijack.start_scheduler_profile)
|
||||
MluHijackObject.apply_hijack(AsyncLLM,
|
||||
"stop_scheduler_profile",
|
||||
AsyncLLM_MluHijack.stop_scheduler_profile)
|
||||
566
vllm_mlu/v1/engine/core.py
Normal file
@@ -0,0 +1,566 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
|
||||
from collections import deque
|
||||
import signal
|
||||
from typing import Any, Callable, cast
|
||||
from concurrent.futures import Future
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import engine_receiver_cache_from_config
|
||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||
from vllm.utils.gc_utils import freeze_gc_heap
|
||||
from vllm.utils.hashing import get_hash_fn_by_name
|
||||
from vllm.utils.system_utils import decorate_logs, set_process_title
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, get_request_block_hasher, init_none_hash
|
||||
from vllm.v1.engine import EngineCoreOutputs
|
||||
from vllm.v1.engine.core import (
|
||||
EngineCore,
|
||||
EngineCoreProc,
|
||||
DPEngineCoreProc,
|
||||
)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
from logging import DEBUG
|
||||
|
||||
import vllm_mlu._mlu_utils as mlu_envs
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.mlu_metric import LLMMetric
|
||||
|
||||
|
||||
class EngineCore_MluHijack(EngineCore):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
executor_fail_callback: Callable | None = None,
|
||||
):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: load_general_plugins in run_engine_core
|
||||
'''
|
||||
# # plugins need to be loaded at the engine/scheduler level too
|
||||
# from vllm.plugins import load_general_plugins
|
||||
# load_general_plugins()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
if vllm_config.parallel_config.data_parallel_rank == 0:
|
||||
logger.info(
|
||||
"Initializing a V1 LLM engine (v%s) with config: %s",
|
||||
VLLM_VERSION,
|
||||
vllm_config,
|
||||
)
|
||||
|
||||
self.log_stats = log_stats
|
||||
|
||||
# Setup Model.
|
||||
self.model_executor = executor_class(vllm_config)
|
||||
if executor_fail_callback is not None:
|
||||
self.model_executor.register_failure_callback(executor_fail_callback)
|
||||
|
||||
self.available_gpu_memory_for_kv_cache = -1
|
||||
|
||||
# Setup KV Caches and update CacheConfig after profiling.
|
||||
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
|
||||
vllm_config
|
||||
)
|
||||
|
||||
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
|
||||
|
||||
self.structured_output_manager = StructuredOutputManager(vllm_config)
|
||||
|
||||
# Setup scheduler.
|
||||
Scheduler = vllm_config.scheduler_config.get_scheduler_cls()
|
||||
|
||||
if len(kv_cache_config.kv_cache_groups) == 0:
|
||||
# Encoder models without KV cache don't support
|
||||
# chunked prefill. But do SSM models?
|
||||
logger.info("Disabling chunked prefill for model without KVCache")
|
||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
|
||||
scheduler_block_size = (
|
||||
vllm_config.cache_config.block_size
|
||||
* vllm_config.parallel_config.decode_context_parallel_size
|
||||
)
|
||||
|
||||
self.scheduler: SchedulerInterface = Scheduler(
|
||||
vllm_config=vllm_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
structured_output_manager=self.structured_output_manager,
|
||||
include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
|
||||
log_stats=self.log_stats,
|
||||
block_size=scheduler_block_size,
|
||||
)
|
||||
self.use_spec_decode = vllm_config.speculative_config is not None
|
||||
if self.scheduler.connector is not None: # type: ignore
|
||||
self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore
|
||||
|
||||
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
|
||||
self.mm_receiver_cache = engine_receiver_cache_from_config(
|
||||
vllm_config, mm_registry
|
||||
)
|
||||
|
||||
# If a KV connector is initialized for scheduler, we want to collect
|
||||
# handshake metadata from all workers so the connector in the scheduler
|
||||
# will have the full context
|
||||
kv_connector = self.scheduler.get_kv_connector()
|
||||
if kv_connector is not None:
|
||||
# Collect and store KV connector xfer metadata from workers
|
||||
# (after KV cache registration)
|
||||
xfer_handshake_metadata = (
|
||||
self.model_executor.get_kv_connector_handshake_metadata()
|
||||
)
|
||||
|
||||
if xfer_handshake_metadata:
|
||||
# xfer_handshake_metadata is list of dicts from workers
|
||||
# Each dict already has structure {tp_rank: metadata}
|
||||
# Merge all worker dicts into a single dict
|
||||
content: dict[int, Any] = {}
|
||||
for worker_dict in xfer_handshake_metadata:
|
||||
if worker_dict is not None:
|
||||
content.update(worker_dict)
|
||||
kv_connector.set_xfer_handshake_metadata(content)
|
||||
|
||||
# Setup batch queue for pipeline parallelism.
|
||||
# Batch queue for scheduled batches. This enables us to asynchronously
|
||||
# schedule and execute batches, and is required by pipeline parallelism
|
||||
# to eliminate pipeline bubbles.
|
||||
self.batch_queue_size = self.model_executor.max_concurrent_batches
|
||||
self.batch_queue: (
|
||||
deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
|
||||
) = None
|
||||
if self.batch_queue_size > 1:
|
||||
logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
|
||||
self.batch_queue = deque(maxlen=self.batch_queue_size)
|
||||
|
||||
self.ec_producer = (
|
||||
vllm_config.ec_transfer_config is not None
|
||||
and vllm_config.ec_transfer_config.is_ec_producer
|
||||
)
|
||||
self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"
|
||||
|
||||
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
|
||||
if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
|
||||
caching_hash_fn = get_hash_fn_by_name(
|
||||
vllm_config.cache_config.prefix_caching_hash_algo
|
||||
)
|
||||
init_none_hash(caching_hash_fn)
|
||||
|
||||
self.request_block_hasher = get_request_block_hasher(
|
||||
scheduler_block_size, caching_hash_fn
|
||||
)
|
||||
|
||||
self.step_fn = (
|
||||
self.step if self.batch_queue is None else self.step_with_batch_queue
|
||||
)
|
||||
self.async_scheduling = vllm_config.scheduler_config.async_scheduling
|
||||
|
||||
# Mark the startup heap as static so that it's ignored by GC.
|
||||
# Reduces pause times of oldest generation collections.
|
||||
freeze_gc_heap()
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: v1 supports offline benchmarking
|
||||
'''
|
||||
self.step_latency = []
|
||||
self.model_exec_latency = []
|
||||
self.mm_encoder_latency = []
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
|
||||
"""Schedule, execute, and make output.
|
||||
|
||||
Returns tuple of outputs and a flag indicating whether the model
|
||||
was executed.
|
||||
"""
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: v1 supports offline benchmarking
|
||||
'''
|
||||
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
|
||||
step_start = LLMMetric.get_mlu_cost_time()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
# Check for any requests remaining in the scheduler - unfinished,
|
||||
# or finished and not yet removed from the batch.
|
||||
if not self.scheduler.has_requests():
|
||||
return {}, False
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
future = self.model_executor.execute_model(scheduler_output, non_block=True)
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = future.result()
|
||||
if model_output is None:
|
||||
model_output = self.model_executor.sample_tokens(grammar_output)
|
||||
|
||||
if self.use_spec_decode and \
|
||||
self.vllm_config.kv_transfer_config is not None and \
|
||||
self.vllm_config.kv_transfer_config.kv_role == "kv_producer":
|
||||
draft_token_ids = self.model_executor.take_draft_token_ids()
|
||||
self.scheduler.draft_token_ids = draft_token_ids
|
||||
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: v1 supports offline benchmarking
|
||||
'''
|
||||
has_sched_reqs = (scheduler_output.total_num_scheduled_tokens > 0)
|
||||
if mlu_envs.VLLM_LATENCY_DEBUG_EN and has_sched_reqs:
|
||||
self.step_latency.append(LLMMetric.get_mlu_cost_time() - step_start)
|
||||
if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and has_sched_reqs:
|
||||
self.model_exec_latency.append(self.get_model_exec_latency())
|
||||
mm_encoder_latency = self.get_mm_encoder_latency()
|
||||
if mm_encoder_latency:
|
||||
self.mm_encoder_latency.append(mm_encoder_latency)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
|
||||
|
||||
def step_with_batch_queue(
|
||||
self,
|
||||
) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
|
||||
"""Schedule and execute batches with the batch queue.
|
||||
Note that if nothing to output in this step, None is returned.
|
||||
|
||||
The execution flow is as follows:
|
||||
1. Try to schedule a new batch if the batch queue is not full.
|
||||
If a new batch is scheduled, directly return an empty engine core
|
||||
output. In other words, fulfilling the batch queue has a higher priority
|
||||
than getting model outputs.
|
||||
2. If there is no new scheduled batch, meaning that the batch queue
|
||||
is full or no other requests can be scheduled, we block until the first
|
||||
batch in the job queue is finished.
|
||||
3. Update the scheduler from the output.
|
||||
"""
|
||||
batch_queue = self.batch_queue
|
||||
assert batch_queue is not None
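
# Illustrative trace (assumed batch_queue_size=2): step 1 schedules batch A,
# queues its future and returns (None, True) without blocking; step 2
# schedules batch B, sees the queue is full, pops A's future, blocks on it
# and returns A's outputs; subsequent steps repeat the pattern, keeping one
# batch in flight to hide pipeline bubbles.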
|
||||
|
||||
# Try to schedule a new batch if the batch queue is not full, but
|
||||
# the scheduler may return an empty batch if all requests are scheduled.
|
||||
# Note that this is not blocking.
|
||||
assert len(batch_queue) < self.batch_queue_size
|
||||
|
||||
model_executed = False
|
||||
deferred_scheduler_output = None
|
||||
if self.scheduler.has_requests():
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
exec_future = self.model_executor.execute_model(
|
||||
scheduler_output, non_block=True
|
||||
)
|
||||
if not self.ec_producer:
|
||||
model_executed = scheduler_output.total_num_scheduled_tokens > 0
|
||||
|
||||
if self.is_pooling_model or not model_executed:
|
||||
# No sampling required (no requests scheduled).
|
||||
future = cast(Future[ModelRunnerOutput], exec_future)
|
||||
else:
|
||||
exec_future.add_done_callback(self._log_err_callback(scheduler_output))
|
||||
|
||||
if not scheduler_output.pending_structured_output_tokens:
|
||||
# We aren't waiting for any tokens, get any grammar output
|
||||
# and sample immediately.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(
|
||||
scheduler_output
|
||||
)
|
||||
future = self.model_executor.sample_tokens(
|
||||
grammar_output, non_block=True
|
||||
)
|
||||
else:
|
||||
# We need to defer sampling until we have processed the model output
|
||||
# from the prior step.
|
||||
deferred_scheduler_output = scheduler_output
|
||||
|
||||
if not deferred_scheduler_output:
|
||||
# Add this step's future to the queue.
|
||||
batch_queue.appendleft((future, scheduler_output))
|
||||
if (
|
||||
model_executed
|
||||
and len(batch_queue) < self.batch_queue_size
|
||||
and not batch_queue[-1][0].done()
|
||||
):
|
||||
# Don't block on next worker response unless the queue is full
|
||||
# or there are no more requests to schedule.
|
||||
return None, True
|
||||
|
||||
elif not batch_queue:
|
||||
# Queue is empty. We should not reach here since this method should
|
||||
# only be called when the scheduler contains requests or the queue
|
||||
# is non-empty.
|
||||
return None, False
|
||||
|
||||
# Block until the next result is available.
|
||||
future, scheduler_output = batch_queue.pop()
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = future.result()
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support disaggregated prefill/decode for MLU.
|
||||
'''
|
||||
if self.use_spec_decode and \
|
||||
self.vllm_config.kv_transfer_config is not None and \
|
||||
self.vllm_config.kv_transfer_config.kv_role == "kv_producer":
|
||||
draft_token_ids = self.model_executor.take_draft_token_ids()
|
||||
self.scheduler.draft_token_ids = draft_token_ids
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
|
||||
# NOTE(nick): We can either handle the deferred tasks here or save
|
||||
# in a field and do it immediately once step_with_batch_queue is
|
||||
# re-called. The latter slightly favors TTFT over TPOT/throughput.
|
||||
if deferred_scheduler_output:
|
||||
# We now have the tokens needed to compute the bitmask for the
|
||||
# deferred request. Get the bitmask and call sample tokens.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(
|
||||
deferred_scheduler_output
|
||||
)
|
||||
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
|
||||
batch_queue.appendleft((future, deferred_scheduler_output))
|
||||
|
||||
return engine_core_outputs, model_executed
|
||||
|
||||
def get_model_exec_latency(self):
|
||||
latency = self.model_executor.get_latency()
|
||||
return latency
|
||||
|
||||
def get_mm_encoder_latency(self):
|
||||
return self.model_executor.get_mm_encoder_latency()
|
||||
|
||||
def get_hfu_info(self, batch, input_len, output_len):
|
||||
return self.model_executor.get_hfu_info(batch, input_len, output_len)
|
||||
|
||||
def get_latency(self):
|
||||
return (self.step_latency, self.model_exec_latency, self.mm_encoder_latency)
|
||||
|
||||
def get_memory_usage(self):
|
||||
peak_memory, block_memory = self.model_executor.get_memory_usage()
|
||||
return (peak_memory, block_memory,
|
||||
self.num_gpu_blocks, self.num_cpu_blocks)
|
||||
|
||||
def recapture_model(self,
|
||||
prefill_enable_mlugraph: bool,
|
||||
batch_size: int,
|
||||
input_len: int):
|
||||
self.model_executor.recapture_model(
|
||||
prefill_enable_mlugraph, batch_size, input_len)
|
||||
|
||||
def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int):
|
||||
self.step_latency = []
|
||||
self.model_exec_latency = []
|
||||
self.mm_encoder_latency = []
|
||||
mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED = use_unchunk_sched
|
||||
mlu_envs.VLLM_V1_MIN_PREFILL_BATCH = min_prefill_batch
|
||||
|
||||
def start_scheduler_profile(self):
|
||||
self.scheduler.start_scheduler_profile()
|
||||
|
||||
def stop_scheduler_profile(self):
|
||||
self.scheduler.stop_scheduler_profile()
|
||||
|
||||
def response_remote_alloc_once(self):
|
||||
self.model_executor.response_remote_alloc_once()
|
||||
|
||||
|
||||
class EngineCoreProc_MluHijack(EngineCoreProc):
|
||||
|
||||
@staticmethod
|
||||
def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
|
||||
"""Launch EngineCore busy loop in background process."""
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: load_general_plugins for mp backend engine
|
||||
'''
|
||||
# plugins need to be loaded at the engine/scheduler level too
|
||||
from vllm.plugins import load_general_plugins
|
||||
load_general_plugins()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
# Signal handler used for graceful termination.
|
||||
# SystemExit exception is only raised once to allow this and worker
|
||||
# processes to terminate without error
|
||||
shutdown_requested = False
|
||||
|
||||
# Ensure we can serialize transformer config after spawning
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit()
|
||||
|
||||
# Either SIGTERM or SIGINT will terminate the engine_core
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
engine_core: EngineCoreProc | None = None
|
||||
try:
|
||||
parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
|
||||
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
|
||||
set_process_title("EngineCore", f"DP{dp_rank}")
|
||||
decorate_logs()
|
||||
# Set data parallel rank for this engine process.
|
||||
parallel_config.data_parallel_rank = dp_rank
|
||||
parallel_config.data_parallel_rank_local = local_dp_rank
|
||||
engine_core = DPEngineCoreProc(*args, **kwargs)
|
||||
else:
|
||||
set_process_title("EngineCore")
|
||||
decorate_logs()
|
||||
engine_core = EngineCoreProc(*args, **kwargs)
|
||||
|
||||
engine_core.run_busy_loop()
|
||||
|
||||
except SystemExit:
|
||||
logger.debug("EngineCore exiting.")
|
||||
raise
|
||||
except Exception as e:
|
||||
if engine_core is None:
|
||||
logger.exception("EngineCore failed to start.")
|
||||
else:
|
||||
logger.exception("EngineCore encountered a fatal error.")
|
||||
engine_core._send_engine_dead()
|
||||
raise e
|
||||
finally:
|
||||
if engine_core is not None:
|
||||
engine_core.shutdown()
|
||||
|
||||
def _process_input_queue(self):
|
||||
"""Exits when an engine step needs to be performed."""
|
||||
|
||||
waited = False
|
||||
while (
|
||||
not self.engines_running
|
||||
and not self.scheduler.has_requests()
|
||||
and not self.batch_queue
|
||||
):
|
||||
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
|
||||
logger.debug("EngineCore waiting for work.")
|
||||
waited = True
|
||||
|
||||
if self.vllm_config.kv_transfer_config is not None and \
|
||||
self.vllm_config.kv_transfer_config.kv_role == "kv_consumer":
|
||||
self.response_remote_alloc_once()
|
||||
if self.input_queue.empty():
|
||||
continue
|
||||
req = self.input_queue.get_nowait()
|
||||
self._handle_client_request(*req)
|
||||
else:
|
||||
req = self.input_queue.get()
|
||||
self._handle_client_request(*req)
|
||||
|
||||
if waited:
|
||||
logger.debug("EngineCore loop active.")
|
||||
|
||||
if self.vllm_config.kv_transfer_config is not None and \
|
||||
self.vllm_config.kv_transfer_config.kv_role == "kv_consumer":
|
||||
self.response_remote_alloc_once()
|
||||
|
||||
# Handle any more client requests.
|
||||
while not self.input_queue.empty():
|
||||
req = self.input_queue.get_nowait()
|
||||
self._handle_client_request(*req)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
"get_mm_encoder_latency",
|
||||
EngineCore_MluHijack.get_mm_encoder_latency)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
"get_model_exec_latency",
|
||||
EngineCore_MluHijack.get_model_exec_latency)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
"get_hfu_info",
|
||||
EngineCore_MluHijack.get_hfu_info)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
"get_latency",
|
||||
EngineCore_MluHijack.get_latency)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
"get_memory_usage",
|
||||
EngineCore_MluHijack.get_memory_usage)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
"recapture_model",
|
||||
EngineCore_MluHijack.recapture_model)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
"init_metric",
|
||||
EngineCore_MluHijack.init_metric)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
"start_scheduler_profile",
|
||||
EngineCore_MluHijack.start_scheduler_profile)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
"stop_scheduler_profile",
|
||||
EngineCore_MluHijack.stop_scheduler_profile)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
EngineCore.__init__,
|
||||
EngineCore_MluHijack.__init__)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
EngineCore.step,
|
||||
EngineCore_MluHijack.step)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
"response_remote_alloc_once",
|
||||
EngineCore_MluHijack.response_remote_alloc_once)
|
||||
MluHijackObject.apply_hijack(EngineCore,
|
||||
EngineCore.step_with_batch_queue,
|
||||
EngineCore_MluHijack.step_with_batch_queue)
|
||||
MluHijackObject.apply_hijack(EngineCoreProc,
|
||||
EngineCoreProc.run_engine_core,
|
||||
EngineCoreProc_MluHijack.run_engine_core)
|
||||
MluHijackObject.apply_hijack(EngineCoreProc,
|
||||
EngineCoreProc._process_input_queue,
|
||||
EngineCoreProc_MluHijack._process_input_queue)
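
# Usage sketch (assumed offline-benchmark flow, using the methods hijacked
# above; the actual call sites live elsewhere in vllm_mlu):
#
#   engine_core.init_metric(use_unchunk_sched=True, min_prefill_batch=1)
#   ...  # run requests via step() / step_with_batch_queue()
#   step_lat, exec_lat, mm_lat = engine_core.get_latency()
#   peak_mem, block_mem, n_gpu_blocks, n_cpu_blocks = engine_core.get_memory_usage()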
|
||||
227
vllm_mlu/v1/engine/core_client.py
Normal file
@@ -0,0 +1,227 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
|
||||
from vllm.v1.engine.core_client import (
|
||||
EngineCoreClient,
|
||||
InprocClient,
|
||||
SyncMPClient,
|
||||
AsyncMPClient,
|
||||
DPAsyncMPClient,
|
||||
DPLBAsyncMPClient,
|
||||
)
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.executor import Executor
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
class EngineCoreClient_MluHiack(EngineCoreClient):
|
||||
|
||||
@staticmethod
|
||||
def make_async_mp_client(
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_addresses: dict[str, str] | None = None,
|
||||
client_count: int = 1,
|
||||
client_index: int = 0,
|
||||
) -> "MPClient":
|
||||
parallel_config = vllm_config.parallel_config
|
||||
client_args = (
|
||||
vllm_config,
|
||||
executor_class,
|
||||
log_stats,
|
||||
client_addresses,
|
||||
client_count,
|
||||
client_index,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: disaggregated deployments use DPAsyncMPClient instead of DPLBAsyncMPClient.
|
||||
'''
|
||||
if parallel_config.data_parallel_size > 1:
|
||||
if parallel_config.data_parallel_external_lb or vllm_config.kv_transfer_config is not None:
|
||||
# External load balancer - client per DP rank.
|
||||
return DPAsyncMPClient(*client_args)
|
||||
# Internal load balancer - client balances to all DP ranks.
|
||||
return DPLBAsyncMPClient(*client_args)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return AsyncMPClient(*client_args)
|
||||
|
||||
|
||||
class InprocClient_MluHiack(InprocClient):
|
||||
|
||||
def get_hfu_info(self, batch, input_len, output_len):
|
||||
return self.engine_core.get_hfu_info(batch, input_len, output_len)
|
||||
|
||||
def get_latency(self):
|
||||
return self.engine_core.get_latency()
|
||||
|
||||
def get_memory_usage(self):
|
||||
return self.engine_core.get_memory_usage()
|
||||
|
||||
def recapture_model(
|
||||
self,
|
||||
prefill_enable_mlugraph: bool,
|
||||
batch_size: int,
|
||||
input_len: int,
|
||||
):
|
||||
return self.engine_core.recapture_model(
|
||||
prefill_enable_mlugraph, batch_size, input_len
|
||||
)
|
||||
|
||||
def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int):
|
||||
return self.engine_core.init_metric(
|
||||
use_unchunk_sched, min_prefill_batch,
|
||||
)
|
||||
|
||||
def start_scheduler_profile(self):
|
||||
self.engine_core.start_scheduler_profile()
|
||||
|
||||
def stop_scheduler_profile(self):
|
||||
self.engine_core.stop_scheduler_profile()
|
||||
|
||||
def response_remote_alloc_once(self) -> None:
|
||||
self.engine_core.response_remote_alloc_once()
|
||||
|
||||
|
||||
class SyncMPClient_MluHiack(SyncMPClient):
|
||||
|
||||
def get_hfu_info(self, batch, input_len, output_len):
|
||||
try:
|
||||
return self.call_utility("get_hfu_info", batch, input_len, output_len)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get HFU info: {str(e)}")
|
||||
|
||||
def get_latency(self):
|
||||
return self.call_utility("get_latency")
|
||||
|
||||
def get_memory_usage(self):
|
||||
return self.call_utility("get_memory_usage")
|
||||
|
||||
def recapture_model(self,
|
||||
prefill_enable_mlugraph: bool,
|
||||
batch_size: int,
|
||||
input_len: int):
|
||||
return self.call_utility("recapture_model",
|
||||
prefill_enable_mlugraph, batch_size, input_len)
|
||||
|
||||
def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int):
|
||||
return self.call_utility("init_metric",
|
||||
use_unchunk_sched,
|
||||
min_prefill_batch)
|
||||
|
||||
def start_scheduler_profile(self):
|
||||
self.call_utility("start_scheduler_profile")
|
||||
|
||||
def stop_scheduler_profile(self):
|
||||
self.call_utility("stop_scheduler_profile")
|
||||
|
||||
def response_remote_alloc_once(self) -> None:
|
||||
self.call_utility("response_remote_alloc_once")
|
||||
|
||||
|
||||
class AsyncMPClient_MluHijack(AsyncMPClient):
|
||||
|
||||
async def start_scheduler_profile(self) -> None:
|
||||
await self.call_utility_async("start_scheduler_profile")
|
||||
|
||||
async def stop_scheduler_profile(self) -> None:
|
||||
await self.call_utility_async("stop_scheduler_profile")
|
||||
|
||||
async def response_remote_alloc_once(self) -> None:
|
||||
await self.call_utility_async("response_remote_alloc_once")
|
||||
|
||||
|
||||
class DPAsyncMPClient_MluHijack(DPAsyncMPClient):
|
||||
|
||||
def get_core_engine_for_request(self, request: EngineCoreRequest):
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: disagg needs the proxy to assign the dp_rank
|
||||
'''
|
||||
if request.data_parallel_rank is not None:
|
||||
# engines are already in rank order
|
||||
return self.core_engines[request.data_parallel_rank]
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
return self.core_engine
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(EngineCoreClient,
|
||||
EngineCoreClient.make_async_mp_client,
|
||||
EngineCoreClient_MluHijack.make_async_mp_client)
|
||||
MluHijackObject.apply_hijack(InprocClient,
|
||||
"get_hfu_info",
|
||||
InprocClient_MluHijack.get_hfu_info)
|
||||
MluHijackObject.apply_hijack(InprocClient,
|
||||
"get_latency",
|
||||
InprocClient_MluHijack.get_latency)
|
||||
MluHijackObject.apply_hijack(InprocClient,
|
||||
"get_memory_usage",
|
||||
InprocClient_MluHijack.get_memory_usage)
|
||||
MluHijackObject.apply_hijack(InprocClient,
|
||||
"recapture_model",
|
||||
InprocClient_MluHijack.recapture_model)
|
||||
MluHijackObject.apply_hijack(InprocClient,
|
||||
"init_metric",
|
||||
InprocClient_MluHijack.init_metric)
|
||||
MluHijackObject.apply_hijack(InprocClient,
|
||||
"start_scheduler_profile",
|
||||
InprocClient_MluHijack.start_scheduler_profile)
|
||||
MluHijackObject.apply_hijack(InprocClient,
|
||||
"stop_scheduler_profile",
|
||||
InprocClient_MluHijack.stop_scheduler_profile)
|
||||
MluHijackObject.apply_hijack(InprocClient,
|
||||
"response_remote_alloc_once",
|
||||
InprocClient_MluHijack.response_remote_alloc_once)
|
||||
MluHijackObject.apply_hijack(SyncMPClient,
|
||||
"get_hfu_info",
|
||||
SyncMPClient_MluHijack.get_hfu_info)
|
||||
MluHijackObject.apply_hijack(SyncMPClient,
|
||||
"get_latency",
|
||||
SyncMPClient_MluHijack.get_latency)
|
||||
MluHijackObject.apply_hijack(SyncMPClient,
|
||||
"get_memory_usage",
|
||||
SyncMPClient_MluHijack.get_memory_usage)
|
||||
MluHijackObject.apply_hijack(SyncMPClient,
|
||||
"recapture_model",
|
||||
SyncMPClient_MluHijack.recapture_model)
|
||||
MluHijackObject.apply_hijack(SyncMPClient,
|
||||
"init_metric",
|
||||
SyncMPClient_MluHijack.init_metric)
|
||||
MluHijackObject.apply_hijack(SyncMPClient,
|
||||
"start_scheduler_profile",
|
||||
SyncMPClient_MluHijack.start_scheduler_profile)
|
||||
MluHijackObject.apply_hijack(SyncMPClient,
|
||||
"stop_scheduler_profile",
|
||||
SyncMPClient_MluHijack.stop_scheduler_profile)
|
||||
MluHijackObject.apply_hijack(SyncMPClient,
|
||||
"response_remote_alloc_once",
|
||||
SyncMPClient_MluHijack.response_remote_alloc_once)
|
||||
MluHijackObject.apply_hijack(AsyncMPClient,
|
||||
"start_scheduler_profile",
|
||||
AsyncMPClient_MluHijack.start_scheduler_profile)
|
||||
MluHijackObject.apply_hijack(AsyncMPClient,
|
||||
"stop_scheduler_profile",
|
||||
AsyncMPClient_MluHijack.stop_scheduler_profile)
|
||||
MluHijackObject.apply_hijack(AsyncMPClient,
|
||||
"response_remote_alloc_once",
|
||||
AsyncMPClient_MluHijack.response_remote_alloc_once)
|
||||
MluHijackObject.apply_hijack(DPAsyncMPClient,
|
||||
DPAsyncMPClient.get_core_engine_for_request,
|
||||
DPAsyncMPClient_MluHijack.get_core_engine_for_request)
|
||||
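# Note: MluHijackObject.apply_hijack is defined in vllm_mlu/mlu_hijack_utils.py
# and is not shown in this diff. The calls above assume it behaves roughly like
# a recorded monkey-patch, e.g. (illustrative sketch only, not the real helper):
#
#     def apply_hijack(target_cls, orig, replacement):
#         name = orig if isinstance(orig, str) else orig.__name__
#         setattr(target_cls, name, replacement)
#
# i.e. string names attach brand-new utility methods, while attributes passed as
# objects (such as EngineCoreClient.make_async_mp_client) are swapped for their
# *_MluHijack counterparts.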
43
vllm_mlu/v1/engine/llm_engine.py
Normal file
43
vllm_mlu/v1/engine/llm_engine.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm__engine__llm_engine__LLMEngine__get_hfu_info(self, batch, input_len, output_len):
|
||||
return self.engine_core.get_hfu_info(batch, input_len, output_len)
|
||||
|
||||
|
||||
def vllm__engine__llm_engine__LLMEngine__get_latency(self):
|
||||
return self.engine_core.get_latency()
|
||||
|
||||
|
||||
def vllm__engine__llm_engine__LLMEngine__get_memory_usage(self):
|
||||
return self.engine_core.get_memory_usage()
|
||||
|
||||
|
||||
def vllm__engine__llm_engine__LLMEngine__start_scheduler_profile(self):
|
||||
self.engine_core.start_scheduler_profile()
|
||||
|
||||
|
||||
def vllm__engine__llm_engine__LLMEngine__stop_scheduler_profile(self):
|
||||
self.engine_core.stop_scheduler_profile()
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
"get_hfu_info",
|
||||
vllm__engine__llm_engine__LLMEngine__get_hfu_info)
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
"get_latency",
|
||||
vllm__engine__llm_engine__LLMEngine__get_latency)
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
"get_memory_usage",
|
||||
vllm__engine__llm_engine__LLMEngine__get_memory_usage)
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
"start_scheduler_profile",
|
||||
vllm__engine__llm_engine__LLMEngine__start_scheduler_profile)
|
||||
MluHijackObject.apply_hijack(LLMEngine,
|
||||
"stop_scheduler_profile",
|
||||
vllm__engine__llm_engine__LLMEngine__stop_scheduler_profile)
|
||||
3
vllm_mlu/v1/executor/__init__.py
Normal file
3
vllm_mlu/v1/executor/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
57
vllm_mlu/v1/executor/abstract.py
Normal file
57
vllm_mlu/v1/executor/abstract.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm__v1__executor__abstract__Executor__get_hfu_info(self, batch, input_len, output_len):
|
||||
output = self.collective_rpc("get_hfu_info", args=([batch, input_len, output_len]))
|
||||
return max(output)
|
||||
|
||||
def vllm__v1__executor__abstract__Executor__get_mm_encoder_latency(self):
|
||||
output = self.collective_rpc("get_mm_encoder_latency")
|
||||
return None if any(item is None for item in output) else max(output)
|
||||
|
||||
def vllm__v1__executor__abstract__Executor__get_latency(self):
|
||||
output = self.collective_rpc("get_latency")
|
||||
return max(output)
|
||||
|
||||
|
||||
def vllm__v1__executor__abstract__Executor__get_memory_usage(self):
|
||||
output = self.collective_rpc("get_memory_usage")
|
||||
return output[0]
|
||||
|
||||
|
||||
def vllm__v1__executor__abstract__Executor__recapture_model(
|
||||
self, prefill_enable_mlugraph: bool, batch_size: int, input_len: int):
|
||||
self.collective_rpc("recapture_model",
|
||||
args=(prefill_enable_mlugraph, batch_size, input_len))
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
Executor,
|
||||
"get_hfu_info",
|
||||
vllm__v1__executor__abstract__Executor__get_hfu_info
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Executor,
|
||||
"get_latency",
|
||||
vllm__v1__executor__abstract__Executor__get_latency
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Executor,
|
||||
"get_mm_encoder_latency",
|
||||
vllm__v1__executor__abstract__Executor__get_mm_encoder_latency
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Executor,
|
||||
"get_memory_usage",
|
||||
vllm__v1__executor__abstract__Executor__get_memory_usage
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Executor,
|
||||
"recapture_model",
|
||||
vllm__v1__executor__abstract__Executor__recapture_model
|
||||
)
|
||||
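# Call path for the utilities hijacked above, as wired by this commit:
#   LLMEngine.get_latency() -> engine_core client -> Executor.get_latency()
#   -> collective_rpc("get_latency") on every worker -> max() across ranks.
# get_hfu_info() and get_mm_encoder_latency() follow the same pattern, while
# get_memory_usage() reports rank 0's value instead of the maximum.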
15
vllm_mlu/v1/executor/multiproc_executor.py
Normal file
15
vllm_mlu/v1/executor/multiproc_executor.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
class MultiprocExecutor_MluHijack(MultiprocExecutor):
|
||||
|
||||
def response_remote_alloc_once(self) -> None:
|
||||
self.collective_rpc("response_remote_alloc_once", unique_reply_rank=self.output_rank)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(MultiprocExecutor,
|
||||
"response_remote_alloc_once",
|
||||
MultiprocExecutor_MluHijack.response_remote_alloc_once)
|
||||
363
vllm_mlu/v1/executor/ray_executor.py
Normal file
363
vllm_mlu/v1/executor/ray_executor.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||
from vllm.v1.executor.ray_executor import RayDistributedExecutor, RayWorkerMetaData
|
||||
from vllm.v1.executor.ray_utils import (
|
||||
RayWorkerWrapper,
|
||||
initialize_ray_cluster,
|
||||
ray,
|
||||
)
|
||||
from vllm.utils.network_utils import (
|
||||
get_distributed_init_method,
|
||||
get_ip,
|
||||
get_open_port,
|
||||
)
|
||||
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RayDistributedExecutor_MluHijack(RayDistributedExecutor):
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
self.forward_dag: ray.dag.CompiledDAG | None = None
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: For MLU, avoid compiling NVIDIA's NCCL
|
||||
'''
|
||||
# For TPU or XPU, avoid compiling NVIDIA's NCCL
|
||||
if current_platform.is_tpu() or current_platform.is_xpu() or \
|
||||
current_platform.is_out_of_tree():
|
||||
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
assert self.uses_ray
|
||||
initialize_ray_cluster(self.parallel_config)
|
||||
placement_group = self.parallel_config.placement_group
|
||||
|
||||
# Disable Ray usage stats collection.
|
||||
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
||||
if ray_usage != "1":
|
||||
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
||||
|
||||
# Create the parallel GPU workers.
|
||||
self._init_workers_ray(placement_group)
|
||||
|
||||
# KV connector setup
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
|
||||
self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
|
||||
self.vllm_config.ec_transfer_config is None
|
||||
or not self.vllm_config.ec_transfer_config.is_ec_producer
|
||||
)
|
||||
|
||||
self.scheduler_output: SchedulerOutput | None = None
|
||||
|
||||
def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
|
||||
# If nsight profiling is enabled, we need to set the profiling
|
||||
# configuration for the ray workers as runtime env.
|
||||
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use default cnperf config.
|
||||
'''
|
||||
runtime_env.update({
|
||||
# use default cnperf config
|
||||
"nsight": "default"
|
||||
})
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
return ray_remote_kwargs
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
|
||||
num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
|
||||
|
||||
# The driver dummy worker does not actually use any resources.
|
||||
# It holds the resource for the driver worker.
|
||||
self.driver_dummy_worker: RayWorkerWrapper | None = None
|
||||
# The remaining workers are the actual ray actors.
|
||||
self.workers: list[RayWorkerWrapper] = []
|
||||
|
||||
# Used in ray compiled DAG: indexed first by PP rank,
|
||||
# and then TP rank. In other words, the inner list is
|
||||
# the TP group of workers for a PP rank.
|
||||
self.pp_tp_workers: list[list[RayWorkerWrapper]] = []
|
||||
|
||||
if self.parallel_config.ray_workers_use_nsight:
|
||||
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
|
||||
ray_remote_kwargs
|
||||
)
|
||||
|
||||
# Create the workers.
|
||||
bundle_indices: list[int]
|
||||
if envs.VLLM_RAY_BUNDLE_INDICES:
|
||||
# Use the bundle indices specified by the user.
|
||||
bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
|
||||
assert len(bundle_indices) == self.parallel_config.world_size, (
|
||||
"VLLM_RAY_BUNDLE_INDICES must have the same size"
|
||||
f" as the world size, but got {bundle_indices=} "
|
||||
f"and {self.parallel_config.world_size=}"
|
||||
)
|
||||
assert len(set(bundle_indices)) == len(bundle_indices), (
|
||||
"VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
|
||||
f" but got {bundle_indices=}"
|
||||
)
|
||||
else:
|
||||
# use the first N bundles that have GPU resources.
|
||||
bundle_indices = []
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if bundle.get(current_platform.ray_device_key, 0):
|
||||
bundle_indices.append(bundle_id)
|
||||
bundle_indices = bundle_indices[: self.parallel_config.world_size]
|
||||
|
||||
worker_metadata: list[RayWorkerMetaData] = []
|
||||
driver_ip = get_ip()
|
||||
for rank, bundle_id in enumerate(bundle_indices):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support ray + cnperf-cli
|
||||
'''
|
||||
if self.parallel_config.ray_workers_use_nsight:
|
||||
ray_remote_kwargs['runtime_env'].update({
|
||||
"nsight": {
|
||||
"o": f"cnperf_rank_{rank}",
|
||||
"force_overwrite": "true"
|
||||
}
|
||||
})
|
||||
if rank == 0:
|
||||
ray_remote_kwargs['runtime_env'].update({
|
||||
"nsight": {}
|
||||
})
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=bundle_id,
|
||||
)
|
||||
|
||||
if current_platform.ray_device_key == "GPU":
|
||||
# NV+AMD GPUs, and Intel XPUs
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=num_gpus,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
|
||||
vllm_config=self.vllm_config, rpc_rank=rank
|
||||
)
|
||||
else:
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
resources={current_platform.ray_device_key: num_gpus},
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
|
||||
vllm_config=self.vllm_config, rpc_rank=rank
|
||||
)
|
||||
worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))
|
||||
|
||||
worker_ips = ray.get(
|
||||
[
|
||||
each.worker.get_node_ip.remote() # type: ignore[attr-defined]
|
||||
for each in worker_metadata
|
||||
]
|
||||
)
|
||||
|
||||
for each, ip in zip(worker_metadata, worker_ips):
|
||||
each.ip = ip
|
||||
|
||||
logger.debug("workers: %s", worker_metadata)
|
||||
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
|
||||
|
||||
ip_counts: dict[str, int] = {}
|
||||
for ip in worker_ips:
|
||||
ip_counts[ip] = ip_counts.get(ip, 0) + 1
|
||||
|
||||
def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
|
||||
"""
|
||||
Sort the workers based on 3 properties:
|
||||
1. If the worker is on the same node as the driver (vllm engine),
|
||||
it should be placed first.
|
||||
2. Then, if the worker is on a node with fewer workers, it should
|
||||
be placed first.
|
||||
3. Finally, if the worker is on a node with smaller IP address, it
|
||||
should be placed first.
|
||||
"""
|
||||
ip = item.ip
|
||||
return 0 if ip == driver_ip else 1, ip_counts[ip], ip
|
||||
|
||||
# After sorting, the workers on the same node will be
|
||||
# close to each other, and the workers on the driver
|
||||
# node will be placed first.
|
||||
sorted_worker_metadata = sorted(
|
||||
worker_metadata, key=sort_by_driver_then_worker_ip
|
||||
)
|
||||
for i, item in enumerate(sorted_worker_metadata):
|
||||
item.adjusted_rank = i
|
||||
self.workers = [item.worker for item in sorted_worker_metadata]
|
||||
rerank_mapping = {
|
||||
item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
|
||||
}
|
||||
self.collective_rpc("adjust_rank", args=(rerank_mapping,))
|
||||
|
||||
# Get the set of GPU IDs used on each node.
|
||||
worker_node_and_gpu_ids = []
|
||||
for worker in [self.driver_dummy_worker] + self.workers:
|
||||
if worker is None:
|
||||
# driver_dummy_worker can be None when using ray spmd worker.
|
||||
continue
|
||||
worker_node_and_gpu_ids.append(
|
||||
ray.get(worker.get_node_and_gpu_ids.remote())
|
||||
) # type: ignore[attr-defined]
|
||||
|
||||
node_workers = defaultdict(list) # node id -> list of worker ranks
|
||||
node_gpus = defaultdict(list) # node id -> list of gpu ids
|
||||
|
||||
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
|
||||
node_workers[node_id].append(i)
|
||||
# `gpu_ids` can be a list of strings or integers.
|
||||
# convert them to integers for consistency.
|
||||
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
|
||||
# string sorting is not sufficient.
|
||||
# see https://github.com/vllm-project/vllm/issues/5590
|
||||
gpu_ids = [int(x) for x in gpu_ids]
|
||||
node_gpus[node_id].extend(gpu_ids)
|
||||
for node_id, gpu_ids in node_gpus.items():
|
||||
node_gpus[node_id] = sorted(gpu_ids)
|
||||
|
||||
all_ips = set(worker_ips + [driver_ip])
|
||||
n_ips = len(all_ips)
|
||||
n_nodes = len(node_workers)
|
||||
|
||||
if n_nodes != n_ips:
|
||||
raise RuntimeError(
|
||||
f"Every node should have a unique IP address. Got {n_nodes}"
|
||||
f" nodes with node ids {list(node_workers.keys())} and "
|
||||
f"{n_ips} unique IP addresses {all_ips}. Please check your"
|
||||
" network configuration. If you set `VLLM_HOST_IP`"
|
||||
" environment variable, make sure it is unique for"
|
||||
" each node."
|
||||
)
|
||||
|
||||
# Set environment variables for the driver and workers.
|
||||
all_args_to_update_environment_variables = [
|
||||
{
|
||||
current_platform.device_control_env_var: ",".join(
|
||||
map(str, node_gpus[node_id])
|
||||
),
|
||||
}
|
||||
for (node_id, _) in worker_node_and_gpu_ids
|
||||
]
|
||||
|
||||
# Environment variables to copy from driver to workers
|
||||
env_vars_to_copy = get_env_vars_to_copy(
|
||||
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
|
||||
additional_vars=set(current_platform.additional_env_vars).union(
|
||||
self.ADDITIONAL_ENV_VARS
|
||||
),
|
||||
destination="workers",
|
||||
)
|
||||
|
||||
# Copy existing env vars to each worker's args
|
||||
for args in all_args_to_update_environment_variables:
|
||||
# TODO: refactor platform-specific env vars
|
||||
for name in env_vars_to_copy:
|
||||
if name in os.environ:
|
||||
args[name] = os.environ[name]
|
||||
|
||||
self._env_vars_for_all_workers = all_args_to_update_environment_variables
|
||||
|
||||
self.collective_rpc(
|
||||
"update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
|
||||
)
|
||||
|
||||
if len(node_gpus) == 1:
|
||||
# in single node case, we don't need to get the IP address.
|
||||
# the loopback address is sufficient
|
||||
# NOTE: a node may have several IP addresses, one for each
|
||||
# network interface. `get_ip()` might return any of them,
|
||||
# while they might not work for communication inside the node
|
||||
# if the network setup is complicated. Using the loopback address
|
||||
# solves this issue, as it always works for communication inside
|
||||
# the node.
|
||||
driver_ip = "127.0.0.1"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
driver_ip, get_open_port()
|
||||
)
|
||||
|
||||
# Initialize the actual workers inside worker wrapper.
|
||||
all_kwargs = []
|
||||
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
|
||||
local_rank = node_workers[node_id].index(rank)
|
||||
kwargs = dict(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=(not self.parallel_config)
|
||||
or (rank % self.parallel_config.tensor_parallel_size == 0),
|
||||
)
|
||||
all_kwargs.append(kwargs)
|
||||
self.collective_rpc("init_worker", args=(all_kwargs,))
|
||||
|
||||
self.collective_rpc("init_device")
|
||||
self.collective_rpc("load_model")
|
||||
|
||||
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
|
||||
self.pp_tp_workers.append([])
|
||||
for tp_rank in range(self.parallel_config.tensor_parallel_size):
|
||||
# PP=2, TP=4
|
||||
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
|
||||
rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
|
||||
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
|
||||
assert pp_rank < len(self.pp_tp_workers)
|
||||
self.pp_tp_workers[pp_rank].append(self.workers[rank])
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
RayDistributedExecutor,
|
||||
RayDistributedExecutor._configure_ray_workers_use_nsight,
|
||||
RayDistributedExecutor_MluHijack._configure_ray_workers_use_nsight
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
RayDistributedExecutor,
|
||||
RayDistributedExecutor._init_workers_ray,
|
||||
RayDistributedExecutor_MluHijack._init_workers_ray
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
RayDistributedExecutor,
|
||||
RayDistributedExecutor._init_executor,
|
||||
RayDistributedExecutor_MluHijack._init_executor
|
||||
)
|
||||
213
vllm_mlu/v1/kv_cache_interface.py
Normal file
213
vllm_mlu/v1/kv_cache_interface.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
|
||||
from math import prod
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import get_dtype_size
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
MambaSpec,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MLUFullAttentionSpec(FullAttentionSpec):
|
||||
|
||||
@property
|
||||
def type_id(self) -> str:
|
||||
return f"mlu_full_attention_{self.block_size}_{self.page_size_bytes}"
|
||||
|
||||
@property
|
||||
def cache_size_bytes(self) -> int:
|
||||
return (
|
||||
2
|
||||
* self.block_size
|
||||
* self.num_kv_heads
|
||||
* self.head_size
|
||||
* get_dtype_size(self.dtype)
|
||||
)
|
||||
|
||||
@property
|
||||
def scale_size_bytes(self) -> int:
|
||||
scale_size_bytes = 0
|
||||
if self.dtype in [torch.int8, torch.uint8]:
|
||||
scale_size_bytes = (
|
||||
2
|
||||
* self.block_size
|
||||
* self.num_kv_heads
|
||||
* get_dtype_size(torch.float32)
|
||||
)
|
||||
return scale_size_bytes
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: calculate kv_cache_scale size when kv_cache_dtype=int8
|
||||
'''
|
||||
return self.cache_size_bytes + self.scale_size_bytes
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MLUMLAAttentionSpec(MLAAttentionSpec):
|
||||
# Used to record k_cache info for the DSA indexer
|
||||
index_head_dim: int = 0
|
||||
index_n_heads: int = 0
|
||||
|
||||
@property
|
||||
def type_id(self) -> str:
|
||||
return f"mlu_mla_attention_{self.block_size}_{self.page_size_bytes}"
|
||||
|
||||
@property
|
||||
def cache_size_bytes(self) -> int:
|
||||
return (
|
||||
self.block_size
|
||||
* self.num_kv_heads
|
||||
* self.head_size
|
||||
* get_dtype_size(self.dtype)
|
||||
)
|
||||
|
||||
@property
|
||||
def scale_size_bytes(self) -> int:
|
||||
scale_size_bytes = 0
|
||||
if self.dtype in [torch.int8, torch.uint8]:
|
||||
scale_size_bytes = (
|
||||
self.block_size
|
||||
* self.num_kv_heads
|
||||
* get_dtype_size(torch.float32)
|
||||
)
|
||||
return scale_size_bytes
|
||||
|
||||
@property
|
||||
def index_cache_size_bytes(self) -> int:
|
||||
return (
|
||||
self.block_size
|
||||
* self.index_n_heads
|
||||
* self.index_head_dim
|
||||
* get_dtype_size(self.dtype)
|
||||
)
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: calculate kv_cache_scale size when kv_cache_dtype=int8
|
||||
@brief: calculate indexer cache size for deepseek v3.2
|
||||
'''
|
||||
return self.cache_size_bytes + self.scale_size_bytes + self.index_cache_size_bytes
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be MLAAttentionSpec."
|
||||
)
|
||||
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
|
||||
assert len(cache_dtype_str_set) == 1, (
|
||||
"All attention layers in the same KV cache group must use the same "
|
||||
"quantization method."
|
||||
)
|
||||
return cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
dtype=specs[0].dtype,
|
||||
cache_dtype_str=cache_dtype_str_set.pop(),
|
||||
index_head_dim=specs[0].index_head_dim,
|
||||
index_n_heads=specs[0].index_n_heads,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MLUSlidingWindowSpec(SlidingWindowSpec):
|
||||
|
||||
@property
|
||||
def type_id(self) -> str:
|
||||
return f"mlu_sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa
|
||||
|
||||
@property
|
||||
def cache_size_bytes(self) -> int:
|
||||
return (
|
||||
2
|
||||
* self.block_size
|
||||
* self.num_kv_heads
|
||||
* self.head_size
|
||||
* get_dtype_size(self.dtype)
|
||||
)
|
||||
|
||||
@property
|
||||
def scale_size_bytes(self) -> int:
|
||||
scale_size_bytes = 0
|
||||
if self.dtype in [torch.int8, torch.uint8]:
|
||||
scale_size_bytes = (
|
||||
2
|
||||
* self.block_size
|
||||
* self.num_kv_heads
|
||||
* get_dtype_size(torch.float32)
|
||||
)
|
||||
return scale_size_bytes
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: calculate kv_cache_scale size when kv_cache_dtype=int8
|
||||
'''
|
||||
return self.cache_size_bytes + self.scale_size_bytes
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
@property
|
||||
def vllm__v1__kv_cache_interface__MambaSpec__page_size_bytes(self) -> int:
|
||||
page_size = sum(
|
||||
prod(shape) * get_dtype_size(dtype)
|
||||
for (shape, dtype) in zip(self.shapes, self.dtypes)
|
||||
)
|
||||
if self.page_size_padded is not None:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support qwen3-next
|
||||
'''
|
||||
# assert self.page_size_padded >= page_size
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return self.page_size_padded
|
||||
return page_size
|
||||
|
||||
MluHijackObject.apply_hijack(MambaSpec,
|
||||
MambaSpec.page_size_bytes,
|
||||
vllm__v1__kv_cache_interface__MambaSpec__page_size_bytes)
|
||||
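# Worked example (illustrative numbers only, not tied to a particular model):
# for an MLUMLAAttentionSpec with block_size=64, num_kv_heads=1, head_size=576,
# dtype=torch.bfloat16, index_n_heads=1 and index_head_dim=128:
#   cache_size_bytes       = 64 * 1 * 576 * 2 = 73728
#   scale_size_bytes       = 0                 (dtype is not int8/uint8)
#   index_cache_size_bytes = 64 * 1 * 128 * 2 = 16384
#   page_size_bytes        = 73728 + 0 + 16384 = 90112
# With kv_cache_dtype=int8 the cache term drops to one byte per element and a
# float32 scale per (block slot, kv head) is added through scale_size_bytes.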
946
vllm_mlu/v1/sample/rejection_sampler.py
Normal file
946
vllm_mlu/v1/sample/rejection_sampler.py
Normal file
@@ -0,0 +1,946 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional
|
||||
import math
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import vllm
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample import rejection_sampler
|
||||
from vllm.v1.sample.rejection_sampler import sample_recovered_tokens
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
|
||||
GREEDY_TEMPERATURE: tl.constexpr = 0
|
||||
# Maximum number of speculative draft tokens allowed per request in a single
|
||||
# step. This value is chosen to be large enough to handle typical use cases.
|
||||
MAX_SPEC_LEN = 128
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief:
|
||||
- Limit maximum batch size due to NRAM memory constraints
|
||||
- Add generate_recovered_uniform_probs function for tmo rejection sampler
|
||||
'''
|
||||
MAX_BATCH_SIZE = 65536
|
||||
|
||||
def generate_recovered_uniform_probs(
|
||||
num_tokens: int,
|
||||
vocab_size: int,
|
||||
num_draft_tokens: list[int],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
q = torch.empty(
|
||||
(num_tokens, vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
|
||||
for i, generator in sampling_metadata.generators.items():
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if num_draft_tokens[i] > 0:
|
||||
q[i].exponential_(generator=generator)
|
||||
return q
|
||||
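# The q tensor holds Exponential(1) samples. Downstream, the recovered token is
# taken as argmax(prob / q) over the vocabulary; by the exponential-race
# (Gumbel-max style) argument this draws an index with probability proportional
# to prob, which is why no explicit normalization of prob is required.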
|
||||
'''
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
'''
|
||||
|
||||
def vllm__v1__sample__rejection_sampler__expand_batch_to_tokens(
|
||||
x: torch.Tensor, # [batch_size]
|
||||
cu_num_tokens: torch.Tensor, # [batch_size]
|
||||
num_tokens: int,
|
||||
replace_from: int = 0,
|
||||
replace_to: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Expand [batch_size] tensor to [num_tokens] tensor based on the number of
|
||||
tokens per batch in cu_num_tokens.
|
||||
|
||||
For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
|
||||
num_tokens = 6, and expanded_x = [a, a, b, b, b, c].
|
||||
|
||||
Args:
|
||||
x: [batch_size] tensor to expand.
|
||||
cu_num_tokens: [batch_size] tensor containing the cumulative number of
|
||||
tokens per batch. Each element represents the total number of
|
||||
tokens up to and including that batch.
|
||||
num_tokens: Total number of tokens.
|
||||
replace_from: int = 0
|
||||
Value to be replaced if it is found in x.
|
||||
replace_to: int = 0
|
||||
Value to replace with when replace_from is found.
|
||||
Returns:
|
||||
expanded_x: [num_tokens] tensor.
|
||||
"""
|
||||
batch_size = x.shape[0]
|
||||
assert cu_num_tokens.shape[0] == batch_size
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
if batch_size > MAX_BATCH_SIZE:
|
||||
raise ValueError(f"Rejection Sampler Not Supported: "
|
||||
f"Batch size exceeds the maximum allowed value of {MAX_BATCH_SIZE}")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
expanded_x = x.new_empty(num_tokens)
|
||||
vllm__v1__sample__rejection_sampler__expand_kernel[(batch_size, )](
|
||||
expanded_x,
|
||||
x,
|
||||
cu_num_tokens,
|
||||
replace_from,
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
|
||||
)
|
||||
return expanded_x
|
||||
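# Reference semantics of the expansion above (host-side sketch for clarity only,
# not executed on the device):
#
#     counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
#     expanded = torch.repeat_interleave(
#         torch.where(x == replace_from, x.new_tensor(replace_to), x), counts)
#
# which reproduces the [a, a, b, b, b, c] expansion from the docstring example.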
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
|
||||
def vllm__v1__sample__rejection_sampler__expand_kernel(
|
||||
output_ptr, # [num_tokens]
|
||||
input_ptr, # [batch_size]
|
||||
cu_num_tokens_ptr, # [batch_size]
|
||||
replace_from,
|
||||
replace_to,
|
||||
MAX_NUM_TOKENS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
if req_idx == 0: # noqa: SIM108
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
# Ensure data types are consistent
|
||||
start_idx = tl.full((), 0, tl.int64)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1).to(tl.int64)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
src_val = tl.load(input_ptr + req_idx)
|
||||
src_val = tl.where(src_val == replace_from, replace_to, src_val)
|
||||
offset = tl.arange(0, MAX_NUM_TOKENS)
|
||||
tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def vllm__v1__sample__rejection_sampler__sample_recovered_tokens_kernel(
|
||||
output_token_ids_ptr, # [num_tokens]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
q_ptr, # [batch_size, vocab_size]
|
||||
vocab_size,
|
||||
PADDED_VOCAB_SIZE: tl.constexpr,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
BLOCK_VOCAB: tl.constexpr = 2048,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
# Early exit for out-of-range positions.
|
||||
pos = tl.program_id(1)
|
||||
if pos >= num_draft_tokens:
|
||||
return
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
max_score = -float("inf")
|
||||
max_index = 0
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
# Temporarily zero out the probability of the draft token.
|
||||
# This is essentially the same as target_prob - draft_prob, except that
|
||||
# n-gram does not have draft_prob. We regard it as 1.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
0)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Replace with block loop due to ngram limitations
|
||||
'''
|
||||
num_blocks = tl.cdiv(PADDED_VOCAB_SIZE, BLOCK_VOCAB)
|
||||
|
||||
for i in tl.range(0, num_blocks):
|
||||
offset = i * BLOCK_VOCAB + tl.arange(0, BLOCK_VOCAB)
|
||||
mask = offset < vocab_size
|
||||
|
||||
if NO_DRAFT_PROBS:
|
||||
prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + offset,
|
||||
mask=mask,
|
||||
other=0
|
||||
)
|
||||
else:
|
||||
draft_prob = tl.load(
|
||||
draft_probs_ptr + (start_idx + pos) * vocab_size + offset,
|
||||
mask=mask,
|
||||
other=0
|
||||
)
|
||||
target_prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + offset,
|
||||
mask=mask,
|
||||
other=0
|
||||
)
|
||||
prob = tl.maximum(target_prob - draft_prob, 0)
|
||||
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
|
||||
# `tl.argmax` will select the maximum value.
|
||||
q = tl.load(q_ptr + req_idx * vocab_size + offset,
|
||||
mask=mask,
|
||||
other=float("-inf"))
|
||||
score = prob / q # Broadcasting elementwise
|
||||
cur_max = tl.argmax(score, axis=0)
|
||||
|
||||
cur_score = score[cur_max]
|
||||
cur_index = offset[cur_max]
|
||||
|
||||
# Manually maintain argmax.
|
||||
if cur_score > max_score:
|
||||
max_score = cur_score
|
||||
max_index = cur_index
|
||||
|
||||
tl.store(output_token_ids_ptr + start_idx + pos, max_index)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if NO_DRAFT_PROBS:
|
||||
# Restore the original probability.
|
||||
tl.store(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
||||
orig_prob)
|
||||
|
||||
"""
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
"""
|
||||
def filter_with_acceptance_rate(output_token_ids, # [batch_size, max_spec_len + 1]
|
||||
fixed_acceptance_rate):
|
||||
"""
|
||||
Filter speculative tokens based on a fixed acceptance rate using batch-level accept/reject decisions.
|
||||
|
||||
This function implements an adaptive acceptance rate control mechanism that maintains a target
|
||||
acceptance rate over time through error compensation and PID-style adjustments.
|
||||
|
||||
Args:
|
||||
output_token_ids (torch.Tensor): Input tensor of shape [batch_size, max_spec_len + 1]
|
||||
where the first column contains base tokens and remaining columns contain speculative tokens
|
||||
fixed_acceptance_rate (float or None): Target acceptance rate between 0.0 and 1.0
|
||||
If None, returns input tensor unchanged
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Modified tensor where rejected batches have all speculative tokens
|
||||
(columns 1 to max_spec_len) set to PLACEHOLDER_TOKEN_ID
|
||||
|
||||
Algorithm Flow:
|
||||
1. **Initialization Phase**:
|
||||
- Extract batch dimensions and device information
|
||||
- Initialize static variables for tracking acceptance statistics:
|
||||
* cumulative_error: Long-term error accumulation
|
||||
* total_batches/accepted_batches: Global acceptance tracking
|
||||
* acceptance_history: Sliding window for recent performance
|
||||
* precision_adjustment: PID controller adjustment factor
|
||||
* recent_adjustments: Error history for PID calculation
|
||||
|
||||
2. **Statistics Calculation**:
|
||||
- Calculate global acceptance rate from all historical data
|
||||
- Calculate sliding window acceptance rate from recent batches
|
||||
- Compute combined error using weighted average of global and window errors
|
||||
- Weight transitions from global-focused (early) to window-focused (later)
|
||||
|
||||
3. **PID Controller Adjustment** (after 50+ batches):
|
||||
- Proportional term: Current error magnitude
|
||||
- Integral term: Accumulated error over recent history
|
||||
- Derivative term: Rate of error change
|
||||
- Combines P, I, D terms to compute precision adjustment factor
|
||||
- Limits adjustment range to prevent over-correction
|
||||
|
||||
4. **Error Correction**:
|
||||
- Applies smooth nonlinear correction based on combined error magnitude
|
||||
- Uses exponential decay mapping for gradual adjustment strength
|
||||
- Handles boundary cases (0.0, 1.0, very low rates) specially
|
||||
|
||||
5. **Gap-based Adjustment**:
|
||||
- Calculates difference between target and actual accepted batches
|
||||
- Applies adaptive threshold-based corrections
|
||||
- Uses exponential smoothing for adjustment strength
|
||||
- Adjustment strength decreases as total batch count increases
|
||||
|
||||
6. **Random Perturbation** (after 100+ batches):
|
||||
- Adds small random noise to prevent local minima
|
||||
- Noise amplitude decreases over time for stability
|
||||
|
||||
7. **Batch Decision**:
|
||||
- Generates random value and compares with adjusted acceptance rate
|
||||
- Makes binary accept/reject decision for entire batch
|
||||
|
||||
8. **Token Modification**:
|
||||
- If accepted: Leave all tokens unchanged
|
||||
- If rejected: Set all speculative tokens (columns 1:) to PLACEHOLDER_TOKEN_ID
|
||||
- This ensures token-level acceptance rate matches batch-level rate
|
||||
|
||||
9. **State Updates**:
|
||||
- Update acceptance counters and history
|
||||
- Update cumulative error using exponential moving average
|
||||
- Prepare state for next function call
|
||||
|
||||
Key Features:
|
||||
- **Batch-level consistency**: All samples in a batch share the same accept/reject fate
|
||||
- **Adaptive control**: Uses multiple feedback mechanisms (global, windowed, PID)
|
||||
- **Error compensation**: Corrects for deviations from target rate over time
|
||||
- **Stability mechanisms**: Includes smoothing, limits, and perturbation for robustness
|
||||
- **Token-level alignment**: Ensures token acceptance rate matches batch acceptance rate
|
||||
|
||||
Note: This function maintains internal state across calls through static variables,
|
||||
so it will converge to the target acceptance rate over multiple invocations.
|
||||
"""
|
||||
if fixed_acceptance_rate is None:
|
||||
return output_token_ids
|
||||
else:
|
||||
# Apply accept/reject decisions for the entire batch based on fixed_acceptance_rate
|
||||
batch_size = output_token_ids.shape[0]
|
||||
max_spec_len = output_token_ids.shape[1] - 1 # Get max_spec_len
|
||||
device = output_token_ids.device
|
||||
|
||||
assert fixed_acceptance_rate >= 0 and fixed_acceptance_rate <= 1
|
||||
|
||||
# Use error compensation method to track global acceptance rate
|
||||
# These are static variables that persist between calls
|
||||
if not hasattr(filter_with_acceptance_rate, "cumulative_error"):
|
||||
filter_with_acceptance_rate.cumulative_error = 0.0
|
||||
if not hasattr(filter_with_acceptance_rate, "total_batches"):
|
||||
filter_with_acceptance_rate.total_batches = 0
|
||||
if not hasattr(filter_with_acceptance_rate, "accepted_batches"):
|
||||
filter_with_acceptance_rate.accepted_batches = 0
|
||||
if not hasattr(filter_with_acceptance_rate, "window_size"):
|
||||
filter_with_acceptance_rate.window_size = 1000 # Sliding window size
|
||||
if not hasattr(filter_with_acceptance_rate, "acceptance_history"):
|
||||
filter_with_acceptance_rate.acceptance_history = [] # Track recent accept/reject history
|
||||
if not hasattr(filter_with_acceptance_rate, "precision_adjustment"):
|
||||
filter_with_acceptance_rate.precision_adjustment = 0.0 # Precision adjustment factor
|
||||
if not hasattr(filter_with_acceptance_rate, "recent_adjustments"):
|
||||
filter_with_acceptance_rate.recent_adjustments = [] # Recent adjustment history
|
||||
if not hasattr(filter_with_acceptance_rate, "target_rate"):
|
||||
filter_with_acceptance_rate.target_rate = fixed_acceptance_rate # Record target acceptance rate
|
||||
else:
|
||||
# If target acceptance rate changes, reset adjustment state
|
||||
if filter_with_acceptance_rate.target_rate != fixed_acceptance_rate:
|
||||
filter_with_acceptance_rate.precision_adjustment = 0.0
|
||||
filter_with_acceptance_rate.recent_adjustments = []
|
||||
filter_with_acceptance_rate.target_rate = fixed_acceptance_rate
|
||||
|
||||
# Update batch count
|
||||
filter_with_acceptance_rate.total_batches += 1
|
||||
|
||||
# Calculate current global acceptance rate
|
||||
global_rate = (filter_with_acceptance_rate.accepted_batches /
|
||||
filter_with_acceptance_rate.total_batches if
|
||||
filter_with_acceptance_rate.total_batches > 0 else 0.0)
|
||||
|
||||
# Calculate sliding window acceptance rate (focusing on recent performance)
|
||||
filter_with_acceptance_rate.acceptance_history.append(0) # Default to reject
|
||||
if len(filter_with_acceptance_rate.acceptance_history) > filter_with_acceptance_rate.window_size:
|
||||
filter_with_acceptance_rate.acceptance_history.pop(0) # Remove oldest record
|
||||
|
||||
window_rate = sum(filter_with_acceptance_rate.acceptance_history) / len(filter_with_acceptance_rate.acceptance_history)
|
||||
|
||||
# Enhance precision for small batches - use smoother weight function
|
||||
batch_weight_factor = 1.0 - math.exp(-filter_with_acceptance_rate.total_batches / 30.0) # Exponential smooth transition
|
||||
|
||||
# Dynamically adjust error weights: rely more on global error for fewer batches,
|
||||
# more on sliding window error as batch count increases
|
||||
window_size = len(filter_with_acceptance_rate.acceptance_history)
|
||||
window_significance = min(window_size / 100.0, 0.9) # Window significance depends on historical data volume
|
||||
window_weight = window_significance * batch_weight_factor
|
||||
global_weight = 1.0 - window_weight
|
||||
|
||||
# Consider both global error and window error
|
||||
combined_error = (global_weight * (global_rate - fixed_acceptance_rate) +
|
||||
window_weight * (window_rate - fixed_acceptance_rate))
|
||||
|
||||
# Update precision adjustment factor - use PID controller style adjustment
|
||||
if filter_with_acceptance_rate.total_batches > 50:
|
||||
# Only perform precision adjustment when there's enough data
|
||||
current_error = global_rate - fixed_acceptance_rate
|
||||
|
||||
# Save recent adjustment history
|
||||
filter_with_acceptance_rate.recent_adjustments.append(current_error)
|
||||
if len(filter_with_acceptance_rate.recent_adjustments) > 20: # Keep recent 20 errors
|
||||
filter_with_acceptance_rate.recent_adjustments.pop(0)
|
||||
|
||||
# PID controller parameters
|
||||
kp = 0.05 # Proportional coefficient
|
||||
ki = 0.001 # Integral coefficient
|
||||
kd = 0.01 # Derivative coefficient
|
||||
|
||||
# Proportional term - current error
|
||||
p_term = current_error
|
||||
|
||||
# Integral term - accumulated error
|
||||
i_term = sum(filter_with_acceptance_rate.recent_adjustments)
|
||||
|
||||
# Derivative term - error change rate
|
||||
d_term = 0.0
|
||||
if len(filter_with_acceptance_rate.recent_adjustments) >= 2:
|
||||
d_term = filter_with_acceptance_rate.recent_adjustments[-1] - filter_with_acceptance_rate.recent_adjustments[-2]
|
||||
|
||||
# Calculate PID adjustment
|
||||
pid_adjustment = kp * p_term + ki * i_term + kd * d_term
|
||||
|
||||
# Update precision adjustment factor
|
||||
filter_with_acceptance_rate.precision_adjustment = pid_adjustment
|
||||
|
||||
# Limit adjustment factor range to prevent over-adjustment
|
||||
max_adjustment = 0.02 + 0.03 * (1.0 - math.exp(-filter_with_acceptance_rate.total_batches / 500.0))
|
||||
filter_with_acceptance_rate.precision_adjustment = max(-max_adjustment, min(max_adjustment, filter_with_acceptance_rate.precision_adjustment))
|
||||
|
||||
# Calculate acceptance rate correction factor
|
||||
error_magnitude = abs(combined_error)
|
||||
correction_factor = 1.0
|
||||
|
||||
# Use more refined error correction logic - use smooth nonlinear correction function
|
||||
if error_magnitude > 0.0005: # Correct even smaller errors
|
||||
# Use smooth correction function instead of piecewise function
|
||||
base_strength = 2.0
|
||||
error_scale = 1.0 - math.exp(-error_magnitude * 50.0) # Exponential decay mapping to [0,1]
|
||||
correction_strength = base_strength + error_scale * 1.5 # Range from 2.0 to 3.5
|
||||
|
||||
# Smooth correction
|
||||
sign = 1 if combined_error > 0 else -1
|
||||
correction_factor = 1.0 + (correction_strength * error_magnitude * sign)
|
||||
|
||||
# Handle boundary cases to avoid division by zero
|
||||
if correction_factor == 0.0:
|
||||
correction_factor = 1.0
|
||||
|
||||
# Apply correction factor
|
||||
adjusted_rate = max(0.0, min(1.0, fixed_acceptance_rate * (1.0 / correction_factor)))
|
||||
|
||||
# Apply precision adjustment factor
|
||||
adjusted_rate = max(0.0, min(1.0, adjusted_rate - filter_with_acceptance_rate.precision_adjustment))
|
||||
|
||||
# More precise boundary case handling
|
||||
if fixed_acceptance_rate > 0 and fixed_acceptance_rate < 0.05:
|
||||
if filter_with_acceptance_rate.total_batches % int(1/fixed_acceptance_rate) == 0:
|
||||
adjusted_rate = 1.0 # Periodically force accept to ensure accuracy in low acceptance rate scenarios
|
||||
# If fixed_acceptance_rate is 0, directly reject
|
||||
elif fixed_acceptance_rate == 0.0:
|
||||
adjusted_rate = 0.0
|
||||
# If fixed_acceptance_rate is 1, directly accept
|
||||
elif fixed_acceptance_rate == 1.0:
|
||||
adjusted_rate = 1.0
|
||||
|
||||
# Make precise adjustments for cases with large remaining errors
|
||||
target_accepted = int(filter_with_acceptance_rate.total_batches * fixed_acceptance_rate + 0.5) # Round to nearest
|
||||
actual_accepted = filter_with_acceptance_rate.accepted_batches
|
||||
acceptance_gap = target_accepted - actual_accepted
|
||||
|
||||
# More aggressive gap adjustment strategy - use adaptive threshold and smooth adjustment
|
||||
gap_relative = abs(acceptance_gap) / max(1, filter_with_acceptance_rate.total_batches)
|
||||
gap_threshold = max(1, int(filter_with_acceptance_rate.total_batches * 0.005)) # Smaller dynamic threshold, at least 1
|
||||
|
||||
# Dynamically adjust acceptance rate based on the gap
|
||||
if abs(acceptance_gap) >= gap_threshold: # Use dynamic threshold
|
||||
# Use smooth adjustment strategy
|
||||
if acceptance_gap > 0: # Need to accept more
|
||||
# Use exponential function for smooth adjustment
|
||||
gap_importance = 1.0 - math.exp(-gap_relative * 50.0) # Map to [0,1]
|
||||
# Adjustment strength decreases as total batch count increases
|
||||
strength_factor = math.exp(-filter_with_acceptance_rate.total_batches / 1000.0)
|
||||
boost_factor = gap_importance * (0.2 + 0.8 * strength_factor) # Range from 0 to 1, decreases with total batch count
|
||||
adjusted_rate = min(1.0, adjusted_rate + (1.0 - adjusted_rate) * boost_factor)
|
||||
else: # Accepted too many, need to reject
|
||||
# Use exponential function for smooth adjustment
|
||||
gap_importance = 1.0 - math.exp(-gap_relative * 50.0) # Map to [0,1]
|
||||
# Adjustment strength decreases as total batch count increases
|
||||
strength_factor = math.exp(-filter_with_acceptance_rate.total_batches / 1000.0)
|
||||
reduction_factor = gap_importance * (0.2 + 0.8 * strength_factor) # Range from 0 to 1, decreases with total batch count
|
||||
adjusted_rate = max(0.0, adjusted_rate * (1.0 - reduction_factor))
|
||||
|
||||
# Add small random perturbation in fixed intervals to enhance convergence
|
||||
if 0.01 < adjusted_rate < 0.99 and filter_with_acceptance_rate.total_batches > 100:
|
||||
# Random perturbation amplitude decreases as batch count increases
|
||||
noise_amplitude = 0.01 * math.exp(-filter_with_acceptance_rate.total_batches / 500.0)
|
||||
noise = (torch.rand(1, device=device).item() * 2 - 1) * noise_amplitude
|
||||
adjusted_rate = max(0.0, min(1.0, adjusted_rate + noise))
|
||||
|
||||
# Generate a random number to decide whether to accept the current batch
|
||||
random_value = torch.rand(1, device=device).item()
|
||||
accept_batch = random_value < adjusted_rate
|
||||
|
||||
# Set some tokens to PLACEHOLDER_TOKEN_ID to achieve specified acceptance rate
|
||||
# Support max_spec_len > 1 cases
|
||||
if accept_batch:
|
||||
# Accept batch - don't modify token_ids
|
||||
filter_with_acceptance_rate.accepted_batches += 1
|
||||
filter_with_acceptance_rate.acceptance_history[-1] = 1 # Update the most recent acceptance status
|
||||
else:
|
||||
# Reject batch - set all speculative tokens (except first column) to PLACEHOLDER_TOKEN_ID
|
||||
# This ensures token-level acceptance rate matches batch-level acceptance rate
|
||||
output_token_ids[:, 1:] = PLACEHOLDER_TOKEN_ID
|
||||
|
||||
# Note: acceptance rate calculation is still based on entire batch accept/reject, no modification needed
|
||||
# But we can add a comment explaining how actual token-level acceptance rate is calculated
|
||||
# Actual token-level acceptance rate = 1 - (number of PLACEHOLDER_TOKEN_ID in output_token_ids / max_spec_len)
|
||||
|
||||
# Update cumulative error - use exponential moving average for smoother error adjustment
|
||||
actual_rate = filter_with_acceptance_rate.accepted_batches / filter_with_acceptance_rate.total_batches
|
||||
# Use EMA to smooth error updates - use adaptive EMA coefficient
|
||||
alpha = 0.05 * math.exp(-filter_with_acceptance_rate.total_batches / 200.0) + 0.01 # EMA coefficient gradually decreases over time
|
||||
filter_with_acceptance_rate.cumulative_error = (alpha * (actual_rate - fixed_acceptance_rate) +
|
||||
(1 - alpha) * filter_with_acceptance_rate.cumulative_error)
|
||||
|
||||
return output_token_ids
|
||||
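# Sanity-check note (not enforced by the code): over many calls the running
# ratio accepted_batches / total_batches is expected to track
# fixed_acceptance_rate, since rejected batches have every speculative column
# overwritten with PLACEHOLDER_TOKEN_ID while accepted batches pass through
# untouched.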
"""
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
"""
|
||||
|
||||
def vllm__v1__sample__rejection_sampler__rejection_sample(
|
||||
# [num_tokens]
|
||||
draft_token_ids: torch.Tensor,
|
||||
# [batch_size]
|
||||
num_draft_tokens: list[int],
|
||||
max_spec_len: int,
|
||||
# [batch_size]
|
||||
cu_num_draft_tokens: torch.Tensor,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: torch.Tensor | None,
|
||||
# [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert draft_token_ids.ndim == 1
|
||||
assert draft_probs is None or draft_probs.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
assert target_probs.ndim == 2
|
||||
|
||||
batch_size = len(num_draft_tokens)
|
||||
num_tokens = draft_token_ids.shape[0]
|
||||
vocab_size = target_probs.shape[-1]
|
||||
device = target_probs.device
|
||||
assert draft_token_ids.is_contiguous()
|
||||
assert draft_probs is None or draft_probs.is_contiguous()
|
||||
assert target_probs.is_contiguous()
|
||||
assert bonus_token_ids.is_contiguous()
|
||||
assert target_probs.shape == (num_tokens, vocab_size)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use tmo rejection_sample for all random sampling requests
|
||||
'''
|
||||
fixed_acceptance_rate = VLLM_MTP_FIXED_ACCEPTANCE_RATE
|
||||
use_fusion_kernel = (sampling_metadata.all_random
|
||||
and max_spec_len == 1
|
||||
and (num_draft_tokens is not None
|
||||
and 0 not in num_draft_tokens))
|
||||
if use_fusion_kernel:
|
||||
# All data is random, use tmo rejection_sample
|
||||
# Generate uniform probabilities for rejection sampling.
|
||||
# [num_tokens]
|
||||
uniform_rand = vllm__v1__sample__rejection_sampler__generate_uniform_probs(
|
||||
num_tokens,
|
||||
num_draft_tokens,
|
||||
sampling_metadata.generators,
|
||||
device,
|
||||
)
|
||||
# generate random probs for recovered tokens
|
||||
uniform_probs = generate_recovered_uniform_probs(
|
||||
num_tokens,
|
||||
vocab_size,
|
||||
num_draft_tokens,
|
||||
sampling_metadata,
|
||||
device,
|
||||
)
|
||||
# num_draft_tokens needs to be a tensor
|
||||
num_draft_tokens_tensor = torch.tensor(num_draft_tokens, dtype=torch.int32, device=device)
|
||||
# tmo rejection_sample requires int32 dtypes
|
||||
bonus_token_ids = bonus_token_ids.to(torch.int32)
|
||||
draft_token_ids = draft_token_ids.to(torch.int32)
|
||||
# use tmo rejection_sample
|
||||
output_token_ids = mlu_ops.rejection_sample(
|
||||
draft_token_ids,
|
||||
num_draft_tokens_tensor,
|
||||
cu_num_draft_tokens,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
uniform_rand,
|
||||
uniform_probs,
|
||||
max_spec_len,
|
||||
high_acc=True # for now, only support high_acc
|
||||
).view(batch_size, max_spec_len + 1)
|
||||
if fixed_acceptance_rate is not None:
|
||||
# set all speculative tokens to placeholder token
|
||||
output_token_ids[:, 1:] = 0
|
||||
output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate)
|
||||
return output_token_ids
|
||||
'''
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
'''
|
||||
|
||||
# Create output buffer.
|
||||
output_token_ids = torch.full(
|
||||
(batch_size, max_spec_len + 1),
|
||||
PLACEHOLDER_TOKEN_ID,
|
||||
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
|
||||
device=device,
|
||||
)
|
||||
|
||||
if sampling_metadata.all_greedy:
|
||||
is_greedy = None
|
||||
else:
|
||||
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
|
||||
if not sampling_metadata.all_random:
|
||||
# Rejection sampling for greedy sampling requests.
|
||||
target_argmax = target_probs.argmax(dim=-1)
|
||||
vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel[(batch_size, )](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
target_argmax,
|
||||
bonus_token_ids,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
has_acceptance_rate=fixed_acceptance_rate is not None,
|
||||
)
|
||||
if sampling_metadata.all_greedy:
|
||||
output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate)
|
||||
return output_token_ids
|
||||
|
||||
# Generate uniform probabilities for rejection sampling.
|
||||
# [num_tokens]
|
||||
uniform_probs = vllm__v1__sample__rejection_sampler__generate_uniform_probs(
|
||||
num_tokens,
|
||||
num_draft_tokens,
|
||||
sampling_metadata.generators,
|
||||
device,
|
||||
)
|
||||
|
||||
# Sample recovered tokens for each position.
|
||||
# [num_tokens]
|
||||
recovered_token_ids = sample_recovered_tokens(
|
||||
max_spec_len,
|
||||
num_draft_tokens,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add fixed acceptance rate check
|
||||
'''
|
||||
# Rejection sampling for random sampling requests.
|
||||
vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel[(batch_size, )](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
recovered_token_ids,
|
||||
uniform_probs,
|
||||
is_greedy,
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
has_acceptance_rate=fixed_acceptance_rate is not None,
|
||||
)
|
||||
output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return output_token_ids
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
def vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
recovered_token_ids_ptr, # [num_tokens]
|
||||
uniform_probs_ptr, # [num_tokens]
|
||||
is_greedy_ptr, # [batch_size]
|
||||
max_spec_len,
|
||||
vocab_size,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
has_acceptance_rate: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
||||
if is_greedy:
|
||||
# Early exit for greedy sampling requests.
|
||||
return
|
||||
|
||||
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
rejected = False
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_prob = 1
|
||||
else:
|
||||
draft_prob = tl.load(
|
||||
draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
|
||||
)
|
||||
target_prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
|
||||
)
|
||||
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add acceptance rate check; always accept if has_acceptance_rate is True
|
||||
'''
|
||||
# NOTE(woosuk): While the draft probability should never be 0,
|
||||
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
||||
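# Standard speculative-decoding acceptance rule: a draft token x is accepted
# with probability min(1, p_target(x) / p_draft(x)), implemented by comparing
# the probability ratio against a uniform sample.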
if (draft_prob > 0 and target_prob / draft_prob >= uniform_prob) or has_acceptance_rate:
|
||||
# Accept.
|
||||
token_id = draft_token_id
|
||||
else:
|
||||
# Reject. Use recovered token.
|
||||
rejected = True
|
||||
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
||||
'''
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
'''
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Check whether to accept the bonus token when has_acceptance_rate is set
|
||||
'''
|
||||
# If has acceptance rate, all tokens are accepted
|
||||
if has_acceptance_rate:
|
||||
rejected = False
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
|
||||
bonus_token_id,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
def vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel(
|
||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||
cu_num_draft_tokens_ptr, # [batch_size]
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
target_argmax_ptr, # [num_tokens]
|
||||
bonus_token_ids_ptr, # [batch_size]
|
||||
is_greedy_ptr, # [batch_size] or None
|
||||
max_spec_len,
|
||||
has_acceptance_rate: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
|
||||
# re-compilation may happen during runtime when is_greedy_ptr is None.
|
||||
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
|
||||
if not is_greedy:
|
||||
# Early exit for non-greedy sampling requests.
|
||||
return
|
||||
|
||||
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
rejected = False
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
target_argmax_id,
|
||||
)
|
||||
if draft_token_id != target_argmax_id:
|
||||
# Reject.
|
||||
rejected = True
|
||||
if has_acceptance_rate:
|
||||
rejected = False
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
|
||||
bonus_token_id,
|
||||
)
|
||||
|
||||
|
||||
def vllm__v1__sample__rejection_sampler__generate_uniform_probs(
|
||||
num_tokens: int,
|
||||
num_draft_tokens: list[int],
|
||||
generators: dict[int, torch.Generator],
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generates a batch of uniform random samples, with optional seeding
|
||||
if available.
|
||||
|
||||
This method creates a tensor of shape `(num_tokens, )` filled
|
||||
with uniform random values in the range [0, 1). If `generators` is provided,
|
||||
the requests with their own seeds will use the provided `torch.Generator`
|
||||
for reproducibility. The samples for the other requests will be generated
|
||||
without a seed.
|
||||
|
||||
Args:
|
||||
num_tokens: int
|
||||
Total number of tokens.
|
||||
num_draft_tokens: List[int]
|
||||
Number of draft tokens per request.
|
||||
generators: Optional[Dict[int, torch.Generator]]
|
||||
A dictionary mapping indices in the batch to
|
||||
`torch.Generator` objects.
|
||||
device: torch.device
|
||||
The device on which to allocate the tensor.
|
||||
Returns:
|
||||
uniform_rand: torch.Tensor
|
||||
A tensor of shape `(num_tokens, )` containing uniform
|
||||
random values in the range [0, 1).
|
||||
"""
|
||||
# NOTE(woosuk): We deliberately use float64 instead of float32 here
|
||||
# because when using float32, there's a non-negligible chance that
|
||||
# uniform_prob is sampled to be exact 0.0 as reported in
|
||||
# https://github.com/pytorch/pytorch/issues/16706. Using float64
|
||||
# mitigates the issue.
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Changed torch.float64 to torch.float32
|
||||
'''
|
||||
uniform_probs = torch.rand(
|
||||
(num_tokens,),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
start_idx = 0
|
||||
for req_idx, n in enumerate(num_draft_tokens):
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if n == 0:
|
||||
continue
|
||||
end_idx = start_idx + n
|
||||
generator = generators.get(req_idx)
|
||||
if generator is not None:
|
||||
uniform_probs[start_idx:end_idx].uniform_(generator=generator)
|
||||
start_idx = end_idx
|
||||
return uniform_probs
|
||||
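# Illustrative usage sketch (not part of the hijack): seeded vs. unseeded
# requests. Assumes a CPU device purely for demonstration; the helper name is
# hypothetical.
def _example_generate_uniform_probs() -> torch.Tensor:
    num_draft_tokens = [2, 3]  # request 0: 2 draft tokens, request 1: 3 draft tokens
    generators = {1: torch.Generator().manual_seed(0)}  # only request 1 is seeded
    # The first two samples are unseeded; the last three are reproducible across runs.
    return vllm__v1__sample__rejection_sampler__generate_uniform_probs(
        num_tokens=5,
        num_draft_tokens=num_draft_tokens,
        generators=generators,
        device=torch.device("cpu"),
    )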
|
||||
|
||||
MluHijackObject.apply_hijack(rejection_sampler,
|
||||
rejection_sampler.generate_uniform_probs,
|
||||
vllm__v1__sample__rejection_sampler__generate_uniform_probs)
|
||||
|
||||
MluHijackObject.apply_hijack(rejection_sampler,
|
||||
rejection_sampler.expand_batch_to_tokens,
|
||||
vllm__v1__sample__rejection_sampler__expand_batch_to_tokens)
|
||||
|
||||
MluHijackObject.apply_hijack(rejection_sampler,
|
||||
rejection_sampler.expand_kernel,
|
||||
vllm__v1__sample__rejection_sampler__expand_kernel)
|
||||
|
||||
MluHijackObject.apply_hijack(rejection_sampler,
|
||||
rejection_sampler.sample_recovered_tokens_kernel,
|
||||
vllm__v1__sample__rejection_sampler__sample_recovered_tokens_kernel)
|
||||
|
||||
MluHijackObject.apply_hijack(rejection_sampler,
|
||||
rejection_sampler.rejection_sample,
|
||||
vllm__v1__sample__rejection_sampler__rejection_sample)
|
||||
|
||||
MluHijackObject.apply_hijack(rejection_sampler,
|
||||
rejection_sampler.rejection_random_sample_kernel,
|
||||
vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel)
|
||||
|
||||
MluHijackObject.apply_hijack(rejection_sampler,
|
||||
rejection_sampler.rejection_greedy_sample_kernel,
|
||||
vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel)
|
||||
118
vllm_mlu/v1/sample/sampler.py
Normal file
118
vllm_mlu/v1/sample/sampler.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.sampler import Sampler, _SAMPLING_EPS
|
||||
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
"""
|
||||
@brief: use tmo random_sample
|
||||
"""
|
||||
def mlu_random_sample(
|
||||
probs: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
is_gumbel_max = True
|
||||
return mlu_ops.random_sample(probs, is_gumbel_max, generators).view(-1)
|
||||
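# Illustrative reference (not the MLU kernel): the Gumbel-max trick that the
# `is_gumbel_max` flag presumably refers to. Drawing an index from
# Categorical(probs) is equivalent to argmax(log(probs) + g) with
# g ~ Gumbel(0, 1). This pure-torch sketch is for intuition only; the actual
# sampling above is done by mlu_ops.random_sample.
def _example_gumbel_max_sample(probs: torch.Tensor) -> torch.Tensor:
    # g = -log(-log(u)), u ~ Uniform(0, 1)
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs)))
    return torch.argmax(torch.log(probs) + gumbel_noise, dim=-1)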
|
||||
class MluSampler(Sampler):
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
logprobs_mode_override: LogprobsMode | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Sample logits based on sampling metadata.
|
||||
|
||||
The various logits processing functions called in this method
|
||||
may update the logits tensor in-place.
|
||||
"""
|
||||
|
||||
logprobs_mode = logprobs_mode_override or self.logprobs_mode
|
||||
assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
|
||||
if sampling_metadata.all_random:
|
||||
greedy_sampled = None
|
||||
else:
|
||||
greedy_sampled = self.greedy_sample(logits)
|
||||
if sampling_metadata.all_greedy:
|
||||
processed_logprobs = None
|
||||
if sampling_metadata.max_num_logprobs is not None:
|
||||
if logprobs_mode == "processed_logits":
|
||||
processed_logprobs = logits
|
||||
elif logprobs_mode == "processed_logprobs":
|
||||
processed_logprobs = self.compute_logprobs(logits)
|
||||
return greedy_sampled, processed_logprobs
|
||||
|
||||
assert sampling_metadata.temperature is not None
|
||||
"""
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use tmo topk_topp_sampler to sample.
|
||||
"""
|
||||
use_tmo = (sampling_metadata.top_k is not None) or (sampling_metadata.top_p is not None)
|
||||
if use_tmo:
|
||||
batch_size, vocab_size = logits.shape
|
||||
index_in = torch.arange(vocab_size, dtype=torch.int32, device=logits.device)
|
||||
(
|
||||
logits_out,
|
||||
sorted_logits_out,
|
||||
index_out,
|
||||
true_select_len,
|
||||
) = mlu_ops.apply_topkp_v2(
|
||||
logits,
|
||||
index_in,
|
||||
sampling_metadata.temperature,
|
||||
None,
|
||||
sampling_metadata.top_k.to(torch.int32) if sampling_metadata.top_k is not None else None,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
processed_logprobs = None
|
||||
if logprobs_mode == "processed_logits":
|
||||
processed_logprobs = logits
|
||||
elif logprobs_mode == "processed_logprobs":
|
||||
processed_logprobs = self.compute_logprobs(logits)
|
||||
|
||||
probs = logits_out.softmax(dim=-1, dtype=torch.float32)
|
||||
random_sampled = mlu_random_sample(probs, sampling_metadata.generators)
|
||||
else:
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(
|
||||
logits, sampling_metadata.temperature, sampling_metadata.all_random
|
||||
)
|
||||
|
||||
# Apply logits processors that only apply to random sampling
|
||||
# (argmax invariant)
|
||||
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled, processed_logprobs = self.topk_topp_sampler(
|
||||
logits,
|
||||
sampling_metadata.generators,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
"""
|
||||
=================
|
||||
End of MLU Hijack
|
||||
=================
|
||||
"""
|
||||
|
||||
if greedy_sampled is None:
|
||||
return random_sampled, processed_logprobs
|
||||
|
||||
sampled = torch.where(
|
||||
sampling_metadata.temperature < _SAMPLING_EPS,
|
||||
greedy_sampled,
|
||||
random_sampled,
|
||||
out=greedy_sampled, # Reuse tensor
|
||||
)
|
||||
return sampled, processed_logprobs
|
||||
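# Illustrative reference (not the MLU kernel): the masking semantics that
# top-k/top-p filtering applies before sampling. mlu_ops.apply_topkp_v2 above
# fuses temperature scaling with this filtering on device; this pure-torch
# sketch only shows the intended semantics for scalar top_k/top_p and is an
# assumption, not the kernel's exact behavior.
def _example_topk_topp_mask(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # Top-k: drop everything below the k-th largest logit.
    kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    logits = logits.masked_fill(logits < kth_value, float("-inf"))
    # Top-p: keep the smallest prefix of the sorted distribution whose
    # cumulative probability reaches top_p (the top token is always kept).
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    remove = cumulative > top_p
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    # Scatter the masked logits back to their original vocabulary positions.
    return sorted_logits.gather(-1, sorted_idx.argsort(dim=-1))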
530
vllm_mlu/v1/spec_decode/dp_eagle.py
Normal file
530
vllm_mlu/v1/spec_decode/dp_eagle.py
Normal file
@@ -0,0 +1,530 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
from typing import List, Optional, Any
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm.config.vllm import CUDAGraphMode
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
|
||||
FlashAttentionMetadata)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, logger
|
||||
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_gather_into_list
|
||||
from vllm.distributed import (
|
||||
get_logits_tp_world_size,
|
||||
get_logits_tp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm_mlu.v1.attention.backends.flash_attn import pad_attn_metadata
|
||||
from vllm_mlu.v1.attention.backends.mla.flashmla import FlashMLAMetadataBuilder
|
||||
from vllm_mlu.v1.attention.backends.utils import (
|
||||
MLUCommonAttentionMetadata, COMMON_METADATA_STR)
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm_mlu.v1.attention.backends.utils import MLUInferMode
|
||||
|
||||
from vllm_mlu.mlu_forward_context import MLUDPMetadata
|
||||
from vllm_mlu.v1.spec_decode.eagle import MluEagleProposer
|
||||
from vllm_mlu.model_executor.models.dp_utils import (
|
||||
enable_data_parallel,
|
||||
DataParallelRuntimeParams
|
||||
)
|
||||
|
||||
class DPMluEagleProposer(MluEagleProposer):
|
||||
|
||||
def get_logits_batch_sizes(self, batch_size: int) -> Optional[List[int]]:
|
||||
tp_world_size, logits_batch_sizes = get_logits_tp_world_size(), None
|
||||
if tp_world_size != get_tensor_model_parallel_world_size():
|
||||
tp_tensor = torch.tensor([batch_size]).to(self.runner.device)
|
||||
outputs = tensor_model_parallel_all_gather_into_list(tp_tensor, get_logits_tp_group())
|
||||
# Convert device tensor to host list
|
||||
outputs = torch.cat(outputs).tolist()
|
||||
logits_batch_sizes = [outputs[i] for i in range(tp_world_size)]
|
||||
return logits_batch_sizes
|
||||
|
||||
def propose_ds_execute_dummy_batch(
|
||||
self,
|
||||
# [num_tokens]
|
||||
target_token_ids: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_positions: torch.Tensor,
|
||||
# [num_tokens, hidden_size]
|
||||
target_hidden_states: torch.Tensor,
|
||||
dp_params: DataParallelRuntimeParams,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# num_scheduled_tokens
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
input_ids[:-1] = target_token_ids[1:]
|
||||
|
||||
# Always skip attention computation for the dummy batch.
|
||||
attn_metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
# Get graph-capture-related information for the DeepSeek model.
|
||||
|
||||
with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_tokens):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=target_positions,
|
||||
hidden_states=target_hidden_states,
|
||||
intermediate_tensors=None,
|
||||
inputs_embeds=None,
|
||||
dp_params=dp_params,
|
||||
)
|
||||
if dp_params is not None:
|
||||
dp_params.logits_batch_split_list = self.get_logits_batch_sizes(num_tokens)
|
||||
_ = self.model.compute_logits(hidden_states, dp_params=dp_params)
|
||||
|
||||
if self.num_speculative_tokens == 1:
|
||||
return
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
@brief: support k > 1; the draft model needs to run k-1 times
|
||||
=============================
|
||||
'''
|
||||
# support k > 1
|
||||
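# Each extra pass mirrors one additional speculative step, presumably so this
# rank issues the same sequence of collective ops as ranks that run real
# batches when num_speculative_tokens > 1.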
for _ in range(self.num_speculative_tokens - 1):
|
||||
new_dp_params = self.runner._get_data_parallel_metadata(
|
||||
num_tokens, num_tokens, True, [1] * num_tokens)
|
||||
with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_tokens):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=target_positions,
|
||||
hidden_states=target_hidden_states,
|
||||
intermediate_tensors=None,
|
||||
inputs_embeds=None,
|
||||
dp_params=new_dp_params,
|
||||
)
|
||||
_ = self.model.compute_logits(hidden_states, dp_params=new_dp_params)
|
||||
'''
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
'''
|
||||
|
||||
def propose(
|
||||
self,
|
||||
# [num_tokens]
|
||||
target_token_ids: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_positions: torch.Tensor,
|
||||
# [num_tokens, hidden_size]
|
||||
target_hidden_states: torch.Tensor,
|
||||
# [batch_size]
|
||||
next_token_ids: torch.Tensor,
|
||||
last_token_indices: torch.Tensor | None,
|
||||
common_attn_metadata: MLUCommonAttentionMetadata,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
# [batch_size]
|
||||
num_rejected_tokens: torch.Tensor,
|
||||
# [num_tokens]
|
||||
token_indices: torch.Tensor,
|
||||
whole_block_table: torch.Tensor,
|
||||
main_model_dp_params: Optional[DataParallelRuntimeParams] = None,
|
||||
time_markers: List =[],
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
if last_token_indices is None:
|
||||
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||
|
||||
if self.method == "eagle3":
|
||||
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
||||
target_hidden_states = self.model.combine_hidden_states(
|
||||
target_hidden_states)
|
||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
hidden_states_indices = last_token_indices
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
if self.attn_metadata_builder is None:
|
||||
attn_metadata_builder = self._get_attention_metadata_builder()
|
||||
else:
|
||||
attn_metadata_builder = self.attn_metadata_builder
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Use full graph with draft model and pad batch_size for dp
|
||||
'''
|
||||
dp_group_max_token_num = max(main_model_dp_params.token_split_list)
|
||||
if dp_group_max_token_num <= self.vllm_config.compilation_config.max_cudagraph_capture_size:
|
||||
batch_descriptor_num_tokens = self.vllm_config.pad_for_cudagraph(dp_group_max_token_num)
|
||||
captured_already = True
|
||||
else:
|
||||
batch_descriptor_num_tokens = num_tokens
|
||||
captured_already = False
|
||||
|
||||
# Determine if we can use full graph
|
||||
decode_only = all(not prefill for prefill in main_model_dp_params.dp_is_prefill)
|
||||
# FIXME(wangchao2): disable mtp graph for ds3.2 with dp for now (core dump)
|
||||
is_dsv32 = self.vllm_config.model_config.hf_config.model_type == "deepseek_v32"
|
||||
use_full_graph = (self.use_cuda_graph
|
||||
and decode_only and captured_already and not is_dsv32)
|
||||
if (self.use_cuda_graph and decode_only and not use_full_graph and not is_dsv32):
|
||||
logger.warning_once(
|
||||
f"Select MLU-V1 Full-MLUGraph mode with drafter, however running in " +
|
||||
f"eager mode: decode_only={decode_only}, captured_already={captured_already}, " +
|
||||
f"num_tokens={num_tokens}."
|
||||
)
|
||||
|
||||
cudagraph_runtime_mode = CUDAGraphMode.FULL if use_full_graph else CUDAGraphMode.NONE
|
||||
batch_descriptor = BatchDescriptor(
|
||||
num_tokens=batch_descriptor_num_tokens,
|
||||
uniform_decode=True,
|
||||
)
|
||||
|
||||
# dp pad batch_size
|
||||
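# With a uniform decode batch and K speculative tokens, each request
# contributes K + 1 tokens, so the padded request count is the padded token
# count divided by (K + 1).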
if use_full_graph:
|
||||
K = self.num_speculative_tokens
|
||||
num_input_tokens = batch_descriptor_num_tokens
|
||||
padded_batch_size = num_input_tokens // (K + 1)
|
||||
else:
|
||||
padded_batch_size = batch_size
|
||||
num_input_tokens = num_tokens
|
||||
|
||||
# change attn metadata num_actual_tokens
|
||||
attn_metadata.num_actual_tokens = num_input_tokens
|
||||
|
||||
common_attn_metadata_copy = None
|
||||
# copy common_attn_metadata when k>1 for draft model,
|
||||
# because DP batch-size padding will modify common_attn_metadata
|
||||
if self.num_speculative_tokens > 1:
|
||||
common_attn_metadata_copy = copy.deepcopy(common_attn_metadata)
|
||||
# pad attn metadata
|
||||
if use_full_graph and enable_data_parallel() and num_input_tokens != num_tokens:
|
||||
assert self.runner is not None
|
||||
# Update attention metadata.
|
||||
pad_attn_metadata(
|
||||
attn_metadata,
|
||||
common_attn_metadata,
|
||||
whole_block_table,
|
||||
self.runner,
|
||||
num_tokens,
|
||||
num_input_tokens,
|
||||
batch_size,
|
||||
padded_batch_size,
|
||||
)
|
||||
|
||||
# Update input ids, pad with 0 if necessary.
|
||||
token_pad_size = num_input_tokens - num_tokens
|
||||
assert token_pad_size >= 0
|
||||
# Update target hidden states, pad with zeros if necessary.
|
||||
if token_pad_size > 0:
|
||||
target_hidden_states = F.pad(
|
||||
target_hidden_states,
|
||||
(0, 0, 0, token_pad_size),
|
||||
value=0.0
|
||||
)
|
||||
|
||||
# Update positions, pad with zeros if necessary.
|
||||
if token_pad_size > 0:
|
||||
target_positions = F.pad(
|
||||
target_positions,
|
||||
(0, token_pad_size),
|
||||
value=0
|
||||
)
|
||||
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.positions[:num_input_tokens] = target_positions
|
||||
self.hidden_states[:num_input_tokens] = target_hidden_states
|
||||
|
||||
kwargs = {} if main_model_dp_params is None else {"dp_params": main_model_dp_params}
|
||||
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
|
||||
start = torch.mlu.Event(enable_timing=True)
|
||||
start.record()
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
batch_descriptor=batch_descriptor if use_full_graph else None,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode):
|
||||
if use_full_graph:
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
intermediate_tensors=None,
|
||||
inputs_embeds=None,
|
||||
is_running_drafter=True,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
**kwargs,
|
||||
)
|
||||
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
|
||||
end = torch.mlu.Event(enable_timing=True)
|
||||
end.record()
|
||||
time_markers.append([start, end])
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
'''
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
'''
|
||||
if main_model_dp_params is not None:
|
||||
# Ensure main_model_dp_params has required attribute before assignment
|
||||
if hasattr(main_model_dp_params, 'logits_batch_split_list'):
|
||||
main_model_dp_params.logits_batch_split_list = self.get_logits_batch_sizes(batch_size)
|
||||
else:
|
||||
raise AttributeError("dp_params must have 'logits_batch_split_list' attribute")
|
||||
|
||||
sample_hidden_states = last_hidden_states[hidden_states_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, dp_params=main_model_dp_params)
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.num_speculative_tokens == 1:
|
||||
# [batch_size, 1]
|
||||
return draft_token_ids.view(-1, 1)
|
||||
|
||||
if self.uses_mrope:
|
||||
positions = target_positions[:, last_token_indices]
|
||||
else:
|
||||
positions = target_positions[last_token_indices]
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
'''
|
||||
hidden_states = last_hidden_states[hidden_states_indices]
|
||||
'''
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
'''
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
|
||||
input_batch_size = batch_size
|
||||
|
||||
if common_attn_metadata.infer_mode != MLUInferMode.DECODE_ONLY:
|
||||
seq_lens_cpu = torch.ones(input_batch_size, dtype=torch.int32,)
|
||||
cu_num_tokens = torch.cumsum(seq_lens_cpu, dim=0)
|
||||
query_start_loc_cpu = torch.empty(input_batch_size + 1, dtype=torch.int32)
|
||||
query_start_loc_cpu[0] = 0
|
||||
query_start_loc_cpu[1:] = cu_num_tokens
|
||||
seq_start_loc_cpu = self.arange[:input_batch_size + 1]
|
||||
common_attn_metadata_k = MLUCommonAttentionMetadata.build(
|
||||
query_start_loc=query_start_loc_cpu.to(self.device, non_blocking=True),
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens_cpu.to(self.device, non_blocking=True),
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
seq_start_loc=seq_start_loc_cpu.to(self.device, non_blocking=True),
|
||||
is_start_loc_match=False, # not prefill
|
||||
max_query_len=1,
|
||||
num_actual_tokens=input_batch_size,
|
||||
num_input_tokens=input_batch_size,
|
||||
num_speculative_tokens=self.num_speculative_tokens,
|
||||
has_prefill_reqs=common_attn_metadata.infer_mode == MLUInferMode.CHUNKED,
|
||||
)
|
||||
else:
|
||||
common_attn_metadata_k = common_attn_metadata_copy
|
||||
common_attn_metadata_k.num_actual_tokens = batch_size
|
||||
common_attn_metadata_k.num_input_tokens = batch_size
|
||||
common_attn_metadata_k.max_query_len = 1
|
||||
common_attn_metadata_k.query_start_loc = self.arange[: batch_size + 1]
|
||||
common_attn_metadata_k.query_start_loc_cpu = torch.from_numpy(
|
||||
self.token_arange_np[: batch_size + 1]
|
||||
).clone()
|
||||
# In padded drafter batch, we need to adjust the sequence lengths
|
||||
# to remove the "padding" (i.e. rejected tokens).
|
||||
# Only apply this adjustment when we have rejected tokens
|
||||
# (i.e., not the first proposal).
|
||||
for token_index in range(self.num_speculative_tokens - 1):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: get dp_params for draft model
|
||||
'''
|
||||
# dp_params for draft model
|
||||
if main_model_dp_params is not None:
|
||||
dp_params = self.runner._get_data_parallel_metadata(
|
||||
input_batch_size,
|
||||
input_batch_size,
|
||||
common_attn_metadata.is_decode_only,
|
||||
[1] * input_batch_size
|
||||
)
|
||||
kwargs = {} if main_model_dp_params is None else {"dp_params": dp_params}
|
||||
'''
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
'''
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
input_ids = draft_token_ids_list[-1].int()
|
||||
if self.uses_mrope:
|
||||
positions += 1
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length.
|
||||
# Since it is complex to remove such requests from the batch,
|
||||
# we keep them in the batch but adjust the position ids
|
||||
# and slot mappings to avoid the
|
||||
# out-of-range access during the model execution.
|
||||
# The draft tokens generated with this adjustment
|
||||
# should be ignored.
|
||||
exceeds_max_model_len = positions[0] >= self.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_positions = torch.where(
|
||||
exceeds_max_model_len.unsqueeze(0),
|
||||
torch.zeros_like(positions),
|
||||
positions,
|
||||
)
|
||||
else:
|
||||
positions += 1
|
||||
exceeds_max_model_len = positions >= self.max_model_len
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
|
||||
|
||||
# For data integrity when async scheduling, we shouldn't use in place
|
||||
# operations in case they are modified in next step's `prepare_input`
|
||||
# of main model.
|
||||
# Increment the sequence lengths.
|
||||
common_attn_metadata_k.seq_lens += 1
|
||||
# This is an out-of-place operation to avoid modifying the original tensor.
|
||||
common_attn_metadata_k.seq_lens_cpu = common_attn_metadata_k.seq_lens_cpu + 1
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
|
||||
common_attn_metadata_k.seq_lens.masked_fill_(exceeds_max_model_len, 1)
|
||||
|
||||
common_attn_metadata_k.num_computed_tokens_cpu = (
|
||||
common_attn_metadata_k.seq_lens_cpu - 1
|
||||
)
|
||||
|
||||
# Compute the slot mapping.
|
||||
if self.uses_mrope:
|
||||
# all dimensions of positions are the same
|
||||
block_numbers = clamped_positions[0] // self.block_size
|
||||
else:
|
||||
block_numbers = clamped_positions // self.block_size
|
||||
block_ids = common_attn_metadata_k.block_table_tensor.gather(
|
||||
dim=1, index=block_numbers.view(-1, 1)
|
||||
)
|
||||
block_ids = block_ids.view(-1)
|
||||
if self.uses_mrope:
|
||||
common_attn_metadata_k.slot_mapping = (
|
||||
block_ids * self.block_size + clamped_positions[0] % self.block_size
|
||||
)
|
||||
else:
|
||||
common_attn_metadata_k.slot_mapping = (
|
||||
block_ids * self.block_size + clamped_positions % self.block_size
|
||||
)
|
||||
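# Worked example: with block_size=16 and a clamped position of 35, the token
# lives in logical block 35 // 16 = 2 at offset 35 % 16 = 3, so its slot is
# block_table[req][2] * 16 + 3.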
# Mask out the slot mappings that exceed the max model length.
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
common_attn_metadata_k.slot_mapping.masked_fill_(
|
||||
exceeds_max_model_len, PADDING_SLOT_ID
|
||||
)
|
||||
|
||||
# Rebuild attention metadata
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
|
||||
common_attn_metadata=common_attn_metadata_k, draft_index=token_index + 1
|
||||
)
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata_k
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
self.hidden_states[:batch_size] = hidden_states
|
||||
if self.supports_mm_inputs:
|
||||
self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)
|
||||
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:input_batch_size]
|
||||
else:
|
||||
input_ids = self.input_ids[:input_batch_size]
|
||||
inputs_embeds = None
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: record latency
|
||||
'''
|
||||
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
|
||||
start = torch.mlu.Event(enable_timing=True)
|
||||
start.record()
|
||||
'''
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
'''
|
||||
# Run the model.
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=self.input_ids[:input_batch_size],
|
||||
positions=self.positions[:input_batch_size],
|
||||
hidden_states=self.hidden_states[:input_batch_size],
|
||||
**kwargs,
|
||||
)
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
|
||||
end = torch.mlu.Event(enable_timing=True)
|
||||
end.record()
|
||||
time_markers.append([start, end])
|
||||
'''
|
||||
=============================
|
||||
End of MLU Hijack
|
||||
=============================
|
||||
'''
|
||||
hidden_states = last_hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
||||
dp_params=dp_params)
|
||||
|
||||
# TODO(wenlong): get more than one token for tree attention
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
# [batch_size, num_speculative_tokens]
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
|
||||
1067
vllm_mlu/v1/spec_decode/eagle.py
Normal file
1067
vllm_mlu/v1/spec_decode/eagle.py
Normal file
File diff suppressed because it is too large
Load Diff
3
vllm_mlu/v1/worker/__init__.py
Normal file
3
vllm_mlu/v1/worker/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
112
vllm_mlu/v1/worker/block_table.py
Normal file
112
vllm_mlu/v1/worker/block_table.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockTable_MluHijack(BlockTable):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
max_num_reqs: int,
|
||||
max_num_blocks_per_req: int,
|
||||
max_num_batched_tokens: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
kernel_block_size: int,
|
||||
dcp_kv_cache_interleave_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
block_size: Block size used for KV cache memory allocation
|
||||
max_num_reqs: Maximum number of concurrent requests supported.
|
||||
max_num_blocks_per_req: Maximum number of blocks per request.
|
||||
max_num_batched_tokens: Maximum number of tokens in a batch.
|
||||
pin_memory: Whether to pin memory for faster GPU transfers.
|
||||
device: Target device for the block table.
|
||||
kernel_block_size: The block_size of underlying attention kernel.
|
||||
Will be the same as `block_size` if `block_size` is supported
|
||||
by the attention kernel.
|
||||
"""
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
|
||||
if kernel_block_size == block_size:
|
||||
# Standard case: allocation and computation use same block size
|
||||
# No block splitting needed, direct mapping
|
||||
self.block_size = block_size
|
||||
self.blocks_per_kv_block = 1
|
||||
self.use_hybrid_blocks = False
|
||||
else:
|
||||
# Hybrid case: allocation block size differs from kernel block size
|
||||
# Memory blocks are subdivided to match kernel requirements
|
||||
# Example: 32-token memory blocks with 16-token kernel blocks
|
||||
# → Each memory block corresponds to 2 kernel blocks
|
||||
if block_size % kernel_block_size != 0:
|
||||
raise ValueError(
|
||||
f"kernel_block_size {kernel_block_size} must divide "
|
||||
f"kv_manager_block_size size {block_size} evenly"
|
||||
)
|
||||
|
||||
self.block_size = kernel_block_size
|
||||
self.blocks_per_kv_block = block_size // kernel_block_size
|
||||
self.use_hybrid_blocks = True
|
||||
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
|
||||
|
||||
self.block_table = self._make_buffer(
|
||||
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
|
||||
)
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: change slot_mapping dtype from int64 to int32
|
||||
'''
|
||||
self.slot_mapping = self._make_buffer(
|
||||
self.max_num_batched_tokens, dtype=torch.int32
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if self.use_hybrid_blocks:
|
||||
self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
|
||||
1, -1
|
||||
)
|
||||
else:
|
||||
self._kernel_block_arange = None
|
||||
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
BlockTable,
|
||||
BlockTable.__init__,
|
||||
BlockTable_MluHijack.__init__
|
||||
)
|
||||
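# Illustrative sketch (an assumption about how _kernel_block_arange is used
# elsewhere, not code from this hijack): expanding allocation block ids into
# kernel block ids when use_hybrid_blocks is True. For example, 32-token
# memory blocks with 16-token kernel blocks expand each memory block into two
# consecutive kernel blocks.
def _example_expand_hybrid_block_ids(memory_block_ids: np.ndarray,
                                     blocks_per_kv_block: int) -> np.ndarray:
    kernel_arange = np.arange(blocks_per_kv_block).reshape(1, -1)
    expanded = memory_block_ids.reshape(-1, 1) * blocks_per_kv_block + kernel_arange
    return expanded.reshape(-1)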
1007
vllm_mlu/v1/worker/dp_gpu_model_runner.py
Normal file
1007
vllm_mlu/v1/worker/dp_gpu_model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
25
vllm_mlu/v1/worker/gpu_input_batch.py
Normal file
25
vllm_mlu/v1/worker/gpu_input_batch.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def split_decodes_and_prefills(self):
|
||||
decodes = 0
|
||||
prefills = 0
|
||||
for i, req_id in enumerate(self.req_ids):
|
||||
req_index = self.req_id_to_index.get(req_id)
|
||||
num_prompt_tokens = self.num_prompt_tokens[req_index]
|
||||
num_computed_tokens = self.num_computed_tokens_cpu[req_index]
|
||||
if num_computed_tokens < num_prompt_tokens:
|
||||
prefills += 1
|
||||
else:
|
||||
decodes += 1
|
||||
return decodes, prefills
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(InputBatch,
|
||||
"split_decodes_and_prefills",
|
||||
split_decodes_and_prefills)
|
||||
4166
vllm_mlu/v1/worker/gpu_model_runner.py
Normal file
4166
vllm_mlu/v1/worker/gpu_model_runner.py
Normal file
File diff suppressed because it is too large
Load Diff
638
vllm_mlu/v1/worker/gpu_worker.py
Normal file
638
vllm_mlu/v1/worker/gpu_worker.py
Normal file
@@ -0,0 +1,638 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A GPU worker class."""
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from types import NoneType
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_pp_group
|
||||
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
|
||||
has_kv_transfer_group)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
from vllm_mlu.model_executor.warmup.kernel_warmup import kernel_warmup
|
||||
from vllm_mlu.profiler.mlu_profiler import MluProfilerWrapper
|
||||
from vllm_mlu.utils import MemorySnapshot, memory_profiling
|
||||
from vllm_mlu._mlu_utils import VLLM_DUMP_MLU_INFO_EN
|
||||
from vllm_mlu.device_allocator.cnmem import CnMemAllocator
|
||||
from vllm_mlu.v1.worker.mlu_quant import MLUWorkerQuant
|
||||
from vllm_mlu.v1.worker.gpu_model_runner import MLUModelRunner
|
||||
from vllm_mlu.v1.worker.dp_gpu_model_runner import DPMLUModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MLUWorker(Worker, MLUWorkerQuant):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
|
||||
WorkerBase.__init__(self, vllm_config=vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker)
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils.import_utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Buffers saved before sleep
|
||||
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
|
||||
logger.info(
|
||||
"Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir,
|
||||
)
|
||||
logger.debug(
|
||||
"Profiler config: record_shapes=%s,"
|
||||
"profile_memory=%s,with_stack=%s,with_flops=%s",
|
||||
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.MLU,
|
||||
],
|
||||
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
|
||||
),
|
||||
)
|
||||
elif envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
self.profiler = MluProfilerWrapper()
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
free_bytes_before_sleep = torch.mlu.mem_get_info()[0]
|
||||
|
||||
# Save the buffers before level 2 sleep
|
||||
if level == 2:
|
||||
model = self.model_runner.model
|
||||
self._sleep_saved_buffers = {
|
||||
name: buffer.cpu().clone() for name, buffer in model.named_buffers()
|
||||
}
|
||||
|
||||
allocator = CnMemAllocator.get_instance()
|
||||
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
||||
free_bytes_after_sleep, total = torch.mlu.mem_get_info()
|
||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||
used_bytes = total - free_bytes_after_sleep
|
||||
assert freed_bytes >= 0, "Memory usage increased after sleeping."
|
||||
logger.info(
|
||||
"Sleep mode freed %.2f GiB memory, "
|
||||
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
|
||||
used_bytes / GiB_bytes)
|
||||
|
||||
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
||||
allocator = CnMemAllocator.get_instance()
|
||||
allocator.wake_up(tags)
|
||||
|
||||
# Restore the buffers after level 2 sleep
|
||||
if len(self._sleep_saved_buffers):
|
||||
model = self.model_runner.model
|
||||
for name, buffer in model.named_buffers():
|
||||
if name in self._sleep_saved_buffers:
|
||||
buffer.data.copy_(self._sleep_saved_buffers[name].data)
|
||||
self._sleep_saved_buffers = {}
|
||||
|
||||
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CnMemAllocator.get_instance()
|
||||
if tag == "weights":
|
||||
assert allocator.get_current_usage() == 0, (
|
||||
"Sleep mode can only be used for one instance per process."
|
||||
)
|
||||
context = allocator.use_memory_pool(tag=tag)
|
||||
else:
|
||||
context = nullcontext()
|
||||
return context
|
||||
|
||||
def init_device(self):
|
||||
if self.device_config.device.type == "mlu":
|
||||
# This env var set by Ray causes exceptions with graph building.
|
||||
os.environ.pop("CNCL_ASYNC_ERROR_HANDLING", None)
|
||||
# if (
|
||||
# self.parallel_config.data_parallel_size > 1
|
||||
# and self.parallel_config.data_parallel_size_local > 0
|
||||
# and self.parallel_config.distributed_executor_backend
|
||||
# not in ["ray", "external_launcher"]
|
||||
# and self.vllm_config.parallel_config.data_parallel_backend != "ray"
|
||||
# ):
|
||||
# # Use local DP rank if available, otherwise use global DP rank.
|
||||
# dp_local_rank = self.parallel_config.data_parallel_rank_local
|
||||
# if dp_local_rank is None:
|
||||
# dp_local_rank = self.parallel_config.data_parallel_rank
|
||||
|
||||
# tp_pp_world_size = (
|
||||
# self.parallel_config.pipeline_parallel_size
|
||||
# * self.parallel_config.tensor_parallel_size
|
||||
# )
|
||||
|
||||
# # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
|
||||
# self.local_rank += dp_local_rank * tp_pp_world_size
|
||||
# assert self.local_rank < torch.mlu.device_count(), (
|
||||
# f"DP adjusted local rank {self.local_rank} is out of bounds. "
|
||||
# )
|
||||
|
||||
self.device = torch.device(f"mlu:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
|
||||
current_platform.check_if_supports_dtype(self.model_config.dtype)
|
||||
|
||||
# Initialize the distributed environment BEFORE taking
|
||||
# memory snapshot
|
||||
# This ensures NCCL buffers are allocated before we measure
|
||||
# available memory
|
||||
init_worker_distributed_environment(
|
||||
self.vllm_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
current_platform.dist_backend,
|
||||
)
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
gc.collect()
|
||||
torch.mlu.empty_cache()
|
||||
|
||||
# take current memory snapshot
|
||||
self.init_snapshot = MemorySnapshot()
|
||||
self.requested_memory = (
|
||||
self.init_snapshot.total_memory
|
||||
* self.cache_config.gpu_memory_utilization
|
||||
)
|
||||
if self.init_snapshot.free_memory < self.requested_memory:
|
||||
GiB = lambda b: round(b / GiB_bytes, 2)
|
||||
raise ValueError(
|
||||
f"Free memory on device "
|
||||
f"({GiB(self.init_snapshot.free_memory)}/"
|
||||
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
|
||||
f"is less than desired GPU memory utilization "
|
||||
f"({self.cache_config.gpu_memory_utilization}, "
|
||||
f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
|
||||
f"utilization or reduce GPU memory used by other processes."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Not support device type: {self.device_config.device}")
|
||||
|
||||
# Construct the model runner
|
||||
model_runner_cls = (DPMLUModelRunner
|
||||
if self._enable_moe_dp_opt() else MLUModelRunner)
|
||||
self.model_runner: MLUModelRunner = model_runner_cls(
|
||||
self.vllm_config, self.device)
|
||||
|
||||
if self.rank == 0:
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
report_usage_stats(self.vllm_config)
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_available_memory(self) -> int:
|
||||
"""Profiles the peak memory usage of the model to determine how much
|
||||
memory can be used for KV cache without OOMs.
|
||||
|
||||
The engine will first conduct a profiling of the existing memory usage.
|
||||
Then, it calculates the free memory that can be used for KV cache in
|
||||
bytes.
|
||||
|
||||
Tip:
|
||||
You may limit the usage of GPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
GiB = lambda b: b / GiB_bytes
|
||||
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
|
||||
# still need a profile run which compiles the model for
|
||||
# max_num_batched_tokens
|
||||
self.model_runner.profile_run()
|
||||
|
||||
msg = (
|
||||
f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
|
||||
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
|
||||
"KV Cache as specified by kv_cache_memory_bytes config and "
|
||||
"skipped memory profiling. This does not respect the "
|
||||
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
|
||||
"config when you want manual control of KV cache memory "
|
||||
"size. If OOM'ed, check the difference of initial free "
|
||||
"memory between the current run and the previous run "
|
||||
"where kv_cache_memory_bytes is suggested and update it "
|
||||
"correspondingly."
|
||||
)
|
||||
logger.info(msg)
|
||||
return kv_cache_memory_bytes
|
||||
|
||||
torch.mlu.empty_cache()
|
||||
torch.mlu.reset_peak_memory_stats()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
with memory_profiling(
|
||||
self.init_snapshot,
|
||||
weights_memory=int(self.model_runner.model_memory_usage),
|
||||
) as profile_result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
self.non_torch_memory = profile_result.non_torch_increase
|
||||
self.peak_activation_memory = profile_result.torch_peak_increase
|
||||
|
||||
free_gpu_memory = profile_result.after_profile.free_memory
|
||||
GiB = lambda b: b / GiB_bytes
|
||||
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
assert self.init_snapshot.free_memory > free_gpu_memory, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
|
||||
f"current free memory {GiB(free_gpu_memory)} GiB. "
|
||||
"This happens when other processes sharing the same container "
|
||||
"release GPU memory while vLLM is profiling during initialization. "
|
||||
"To fix this, ensure consistent GPU memory allocation or "
|
||||
"isolate vLLM in its own container."
|
||||
)
|
||||
self.available_kv_cache_memory_bytes = (
|
||||
self.requested_memory - profile_result.non_kv_cache_memory
|
||||
)
|
||||
|
||||
unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
|
||||
logger.debug(
|
||||
"Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
|
||||
GiB(self.init_snapshot.free_memory),
|
||||
self.cache_config.gpu_memory_utilization,
|
||||
GiB(self.requested_memory),
|
||||
)
|
||||
logger.debug(
|
||||
"Free memory after profiling: %.2f GiB (total), "
|
||||
"%.2f GiB (within requested)",
|
||||
GiB(free_gpu_memory),
|
||||
GiB(free_gpu_memory - unrequested_memory),
|
||||
)
|
||||
logger.debug(profile_result)
|
||||
logger.info_once(
|
||||
"Available KV cache memory: %.2f GiB",
|
||||
GiB(self.available_kv_cache_memory_bytes),
|
||||
scope="local",
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
self.peak_memory = profile_result.non_kv_cache_memory
|
||||
self.block_memory = self.available_kv_cache_memory_bytes
|
||||
|
||||
|
||||
return int(self.available_kv_cache_memory_bytes)
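# A minimal standalone sketch (hypothetical helper and numbers, not part of
# this patch) of the budget arithmetic performed above: the KV cache receives
# whatever remains of the requested memory after the profiled non-KV-cache use.
def _example_kv_cache_budget(free_bytes: int, utilization: float,
                             non_kv_cache_bytes: int) -> int:
    # requested_memory = initial free memory * gpu_memory_utilization
    requested = int(free_bytes * utilization)
    # subtract weights + peak activation + non-torch memory
    return max(requested - non_kv_cache_bytes, 0)
# e.g. 80 GiB free with utilization 0.9 and 20 GiB of non-KV-cache usage
# leaves roughly 52 GiB for the KV cache.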
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
|
||||
# Init kv cache connector here, because it requires
|
||||
# `kv_cache_config`.
|
||||
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
|
||||
# because `initialize_kv_cache` will inject kv cache groups not
|
||||
# related to kv cache connector (e.g. kv cache sharing layers).
|
||||
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
|
||||
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CnMemAllocator.get_instance()
|
||||
context = allocator.use_memory_pool(tag="kv_cache")
|
||||
else:
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
# warm up sizes that are not in cudagraph capture sizes,
|
||||
# but users still want to compile for better performance,
|
||||
# e.g. for the max-num-batched token size in chunked prefill.
|
||||
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
|
||||
if not self.model_config.enforce_eager:
|
||||
warmup_sizes = [
|
||||
x for x in warmup_sizes
|
||||
if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes
|
||||
]
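# e.g. with compile_sizes=[256, 8192] and cudagraph capture sizes that already
# include 256, only 8192 remains here and is warmed up eagerly below
# (sizes are hypothetical).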
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for size in sorted(warmup_sizes, reverse=True):
|
||||
logger.info("Compile and warming up model for size %d", size)
|
||||
self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
|
||||
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
|
||||
|
||||
# Warmup and tune the kernels used during model execution before
|
||||
# cuda graph capture.
|
||||
kernel_warmup(self)
|
||||
|
||||
cuda_graph_memory_bytes = 0
|
||||
if not self.model_config.enforce_eager:
|
||||
cuda_graph_memory_bytes = self.model_runner.capture_model()
|
||||
|
||||
if self.cache_config.kv_cache_memory_bytes is None and hasattr(
|
||||
self, "peak_activation_memory"
|
||||
):
|
||||
# Suggests optimal kv cache memory size if we rely on
|
||||
# memory_profiling to guess the kv cache memory size which
|
||||
# provides peak_activation_memory and a few other memory
|
||||
# consumption. `memory_profiling` does not consider
|
||||
# CUDAGraph memory size and may not utilize all gpu memory.
|
||||
# Users may want fine-grained control to specify kv cache
|
||||
# memory size.
|
||||
GiB = lambda b: round(b / GiB_bytes, 2)
|
||||
|
||||
# empirically observed that the memory profiling may
|
||||
# slightly underestimate the memory consumption.
|
||||
# So leave a small buffer (=150MiB) to avoid OOM.
|
||||
redundancy_buffer_memory = 150 * (1 << 20)
|
||||
non_kv_cache_memory = (
|
||||
self.model_runner.model_memory_usage
|
||||
+ self.peak_activation_memory
|
||||
+ self.non_torch_memory
|
||||
+ cuda_graph_memory_bytes
|
||||
)
|
||||
kv_cache_memory_bytes_to_gpu_limit = (
|
||||
self.init_snapshot.free_memory
|
||||
- non_kv_cache_memory
|
||||
- redundancy_buffer_memory
|
||||
)
|
||||
kv_cache_memory_bytes_to_requested_limit = (
|
||||
int(self.requested_memory)
|
||||
- non_kv_cache_memory
|
||||
- redundancy_buffer_memory
|
||||
)
|
||||
|
||||
msg = (
|
||||
f"Free memory on device "
|
||||
f"({GiB(self.init_snapshot.free_memory)}/"
|
||||
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
|
||||
f"Desired GPU memory utilization is "
|
||||
f"({self.cache_config.gpu_memory_utilization}, "
|
||||
f"{GiB(self.requested_memory)} GiB). "
|
||||
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
|
||||
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
|
||||
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
|
||||
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
|
||||
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
|
||||
f"config with `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_requested_limit}` "
|
||||
f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
|
||||
f"into requested memory, or `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_gpu_limit}` "
|
||||
f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
|
||||
f"utilize gpu memory. Current kv cache memory in use is "
|
||||
f"{GiB(self.available_kv_cache_memory_bytes)} GiB."
|
||||
)
|
||||
|
||||
logger.debug(msg)
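# Acting on the suggestion above would look like (value is hypothetical):
#   vllm serve <model> --kv-cache-memory=55834574848   # ~52 GiB, fixed budget
# instead of tuning --gpu-memory-utilization.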
|
||||
|
||||
# Warm up sampler and preallocate memory buffer for logits and other
|
||||
# sampling related tensors of max possible shape to avoid memory
|
||||
# fragmentation issue.
|
||||
# NOTE: This is called after `capture_model` on purpose to prevent
|
||||
# memory buffers from being cleared by `torch.cuda.empty_cache`.
|
||||
if get_pp_group().is_last_rank:
|
||||
max_num_reqs = min(
|
||||
self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens,
|
||||
)
|
||||
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
hidden_states, last_hidden_states = self.model_runner._dummy_run(
|
||||
num_tokens=max_num_reqs,
|
||||
skip_eplb=True,
|
||||
)
|
||||
if self.model_runner.is_pooling_model:
|
||||
self.model_runner._dummy_pooler_run(hidden_states)
|
||||
else:
|
||||
self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self, scheduler_output: "SchedulerOutput",
|
||||
) -> ModelRunnerOutput | None:
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
|
||||
all_gather_tensors = {
|
||||
"residual": not is_residual_scattered_for_sp(
|
||||
self.vllm_config, num_input_tokens
|
||||
)
|
||||
}
|
||||
if forward_pass and not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
get_pp_group().recv_tensor_dict(
|
||||
all_gather_group=get_tp_group(),
|
||||
all_gather_tensors=all_gather_tensors,
|
||||
)
|
||||
)
|
||||
|
||||
with self.annotate_profile(scheduler_output):
|
||||
output = self.model_runner.execute_model(
|
||||
scheduler_output, intermediate_tensors
|
||||
)
|
||||
if isinstance(output, (ModelRunnerOutput, NoneType)):
|
||||
return output
|
||||
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
assert (
|
||||
parallel_config.distributed_executor_backend != "external_launcher"
|
||||
and not get_pp_group().is_last_rank
|
||||
)
|
||||
|
||||
get_pp_group().send_tensor_dict(
|
||||
output.tensors,
|
||||
all_gather_group=get_tp_group(),
|
||||
all_gather_tensors=all_gather_tensors,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _enable_moe_dp_opt(self):
|
||||
'''
|
||||
Enable the MLU-optimized DP scheme for supported MoE models;
|
||||
otherwise the native DP implementation will be used.
|
||||
'''
|
||||
# case 0: data parallelism enabled
|
||||
enable_dp = self.parallel_config.data_parallel_size > 1
|
||||
# case 1: DeepSeek MLA models
|
||||
is_ds_mla = self.model_config.is_deepseek_mla
|
||||
# case 2: supported MoE models (qwen3_moe, glm4_moe)
|
||||
is_supported_moe_model = hasattr(self.model_config.hf_text_config, "model_type") and \
|
||||
self.model_config.hf_text_config.model_type in ('qwen3_moe', 'glm4_moe')
|
||||
# case 3: private models
|
||||
is_private_model = getattr(self.model_config.hf_config, "is_private", False)
|
||||
return enable_dp and (is_ds_mla or is_supported_moe_model or is_private_model)
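# For example (hypothetical configurations): data_parallel_size=2 with a
# qwen3_moe or DeepSeek-MLA checkpoint takes the MLU-optimized DP path, while
# data_parallel_size=1, or an unlisted dense model, falls back to the native
# DP implementation.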
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
if self._enable_moe_dp_opt():
|
||||
self.model_runner.moe_dp_execute_dummy_batch(1)
|
||||
else:
|
||||
self.model_runner._dummy_run(1, uniform_decode=True)
|
||||
|
||||
def response_remote_alloc_once(self) -> None:
|
||||
self.model_runner.response_remote_alloc_once()
|
||||
|
||||
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info(
|
||||
"[Elastic EP] Starting expert resharding before scaling down..."
|
||||
)
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
assert self.model_runner.eplb_state is not None
|
||||
self.model_runner.eplb_state.rearrange(
|
||||
execute_shuffle=True,
|
||||
global_expert_load=None,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
torch.mlu.synchronize()
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed!")
|
||||
|
||||
def reinitialize_distributed(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
cleanup_dist_env_and_memory,
|
||||
get_ep_group,
|
||||
)
|
||||
|
||||
old_ep_size = get_ep_group().world_size
|
||||
old_ep_rank = get_ep_group().rank
|
||||
new_ep_size = (
|
||||
reconfig_request.new_data_parallel_size
|
||||
* get_tp_group().world_size
|
||||
* get_pp_group().world_size
|
||||
)
|
||||
if new_ep_size < old_ep_size:
|
||||
self._eplb_before_scale_down(old_ep_size, new_ep_size)
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
|
||||
):
|
||||
assert old_ep_rank >= new_ep_size
|
||||
# shutdown
|
||||
return
|
||||
|
||||
self._reconfigure_parallel_config(reconfig_request)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
init_worker_distributed_environment(
|
||||
self.vllm_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
current_platform.dist_backend,
|
||||
)
|
||||
|
||||
global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
|
||||
|
||||
if new_ep_size > old_ep_size:
|
||||
assert global_expert_loads is not None
|
||||
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
|
||||
|
||||
def get_hfu_info(self, batch, input_len, output_len):
|
||||
try:
|
||||
self.model_runner.model.collect_hfu_io_effciency_info(batch, input_len, output_len)
|
||||
if VLLM_DUMP_MLU_INFO_EN:
|
||||
return self.model_runner.model.hfu_info, self.model_runner.model.io_efficiency
|
||||
else:
|
||||
return self.model_runner.model.flops_info, 0.0
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"Model match failure when get HFU info, please check if an init method was registed."
|
||||
) from e
|
||||
|
||||
def _get_latency(self, time_markers):
|
||||
total_latency = 0
|
||||
if not isinstance(time_markers, list):
|
||||
time_markers = [time_markers]
|
||||
for time_marker in time_markers:
|
||||
start, end = time_marker
|
||||
latency = start.elapsed_time(end)
|
||||
total_latency += latency
|
||||
return total_latency
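# A rough sketch of how a time-marker pair could be produced by the runner
# (assuming torch.mlu.Event mirrors torch.cuda.Event; names are illustrative):
#
#   start = torch.mlu.Event(enable_timing=True)
#   end = torch.mlu.Event(enable_timing=True)
#   start.record(); ...model execution...; end.record()
#   torch.mlu.synchronize()
#   self.time_markers = (start, end)  # later consumed via start.elapsed_time(end)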
|
||||
|
||||
def get_latency(self):
|
||||
return self._get_latency(self.model_runner.time_markers)
|
||||
|
||||
def get_mm_encoder_latency(self):
|
||||
if not hasattr(self.model_runner, "mm_time_markers"):
|
||||
return None
|
||||
mm_time_markers = self.model_runner.mm_time_markers
|
||||
return None if len(mm_time_markers) == 0 else\
|
||||
self._get_latency(mm_time_markers)
|
||||
|
||||
def get_memory_usage(self):
|
||||
return (self.peak_memory, self.block_memory)
|
||||
|
||||
def recapture_model(self,
|
||||
prefill_enable_mlugraph: bool,
|
||||
batch_size: int,
|
||||
input_len: int):
|
||||
# Reset history capture context
|
||||
self.model_runner.reset_capture_context(
|
||||
prefill_enable_mlugraph, batch_size, input_len)
|
||||
# Re-capture the decode graph (full graph or piecewise graph)
|
||||
self.compile_or_warm_up_model()
|
||||
120
vllm_mlu/v1/worker/kv_connector_model_runner_mixin.py
Normal file
120
vllm_mlu/v1/worker/kv_connector_model_runner_mixin.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
"""
|
||||
Define KV connector functionality mixin for model runners.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from collections.abc import Generator
|
||||
from contextlib import AbstractContextManager, contextmanager, nullcontext
|
||||
from typing import (
|
||||
TYPE_CHECKING, # noqa: UP035
|
||||
)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer import (
|
||||
ensure_kv_transfer_shutdown,
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
KVConnectorOutput,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
|
||||
class KVConnectorModelRunnerMixin_MluHijack(KVConnectorModelRunnerMixin):
|
||||
@staticmethod
|
||||
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
|
||||
# Update KVConnector with the KVConnector metadata forward().
|
||||
if has_kv_transfer_group():
|
||||
kv_connector = get_kv_transfer_group()
|
||||
assert isinstance(kv_connector, KVConnectorBase)
|
||||
assert scheduler_output.kv_connector_metadata is not None
|
||||
kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
|
||||
|
||||
# Background KV cache transfers happen here.
|
||||
# These transfers are designed to be async and the requests
|
||||
# involved may be disjoint from the running requests.
|
||||
# Do this here to save a collective_rpc.
|
||||
kv_connector.start_load_kv(get_forward_context())
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support disagg for MLU.
|
||||
'''
|
||||
kv_connector.request_remote_memory_send()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
# This context manager must be used within an active forward context.
|
||||
# It encapsulates the entire KV connector lifecycle within execute_model
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def _get_kv_connector_output(
|
||||
scheduler_output: "SchedulerOutput", wait_for_save: bool = True
|
||||
) -> Generator[KVConnectorOutput, None, None]:
|
||||
output = KVConnectorOutput()
|
||||
|
||||
# Update KVConnector with the KVConnector metadata forward().
|
||||
kv_connector = get_kv_transfer_group()
|
||||
assert isinstance(kv_connector, KVConnectorBase)
|
||||
assert scheduler_output.kv_connector_metadata is not None
|
||||
kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
|
||||
|
||||
# Background KV cache transfers happen here.
|
||||
# These transfers are designed to be async and the requests
|
||||
# involved may be disjoint from the running requests.
|
||||
# Do this here to save a collective_rpc.
|
||||
kv_connector.start_load_kv(get_forward_context())
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: support disagg for MLU.
|
||||
'''
|
||||
kv_connector.request_remote_memory_send()
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
try:
|
||||
yield output
|
||||
finally:
|
||||
output.finished_sending, output.finished_recving = (
|
||||
kv_connector.get_finished(scheduler_output.finished_req_ids)
|
||||
)
|
||||
output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()
|
||||
|
||||
output.kv_connector_stats = (
|
||||
KVConnectorModelRunnerMixin.get_kv_connector_stats()
|
||||
)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(KVConnectorModelRunnerMixin,
|
||||
KVConnectorModelRunnerMixin.maybe_setup_kv_connector,
|
||||
KVConnectorModelRunnerMixin_MluHijack.maybe_setup_kv_connector)
|
||||
MluHijackObject.apply_hijack(KVConnectorModelRunnerMixin,
|
||||
KVConnectorModelRunnerMixin._get_kv_connector_output,
|
||||
KVConnectorModelRunnerMixin_MluHijack._get_kv_connector_output)
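# Both entry points are hijacked so that request_remote_memory_send() runs
# inside the KV-transfer setup path regardless of whether the runner goes
# through maybe_setup_kv_connector() or the _get_kv_connector_output()
# context manager.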
|
||||
33
vllm_mlu/v1/worker/lora_model_runner_mixin.py
Normal file
33
vllm_mlu/v1/worker/lora_model_runner_mixin.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import List
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm_mlu__v1__worker__LoRAModelRunnerMixin__add_dummy_loras(self, num_loras: int) -> List[LoRARequest]:
|
||||
assert num_loras > 0
|
||||
assert self.lora_manager is not None
|
||||
|
||||
dummy_lora_requests: list[LoRARequest] = []
|
||||
with self.lora_manager.dummy_lora_cache():
|
||||
for idx in range(num_loras):
|
||||
lora_id = idx + 1
|
||||
dummy_lora_request = LoRARequest(
|
||||
lora_name=f"capture_graph_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path",
|
||||
)
|
||||
self.lora_manager.add_dummy_lora(dummy_lora_request,
|
||||
rank=self.LORA_WARMUP_RANK)
|
||||
dummy_lora_requests.append(dummy_lora_request)
|
||||
return dummy_lora_requests
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(LoRAModelRunnerMixin,
|
||||
"add_dummy_loras",
|
||||
vllm_mlu__v1__worker__LoRAModelRunnerMixin__add_dummy_loras)
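# Conceptually, the hijack above amounts to a plain monkey-patch (simplified
# sketch; MluHijackObject may additionally record the original for restore):
#
#   LoRAModelRunnerMixin.add_dummy_loras = (
#       vllm_mlu__v1__worker__LoRAModelRunnerMixin__add_dummy_loras)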
|
||||
281
vllm_mlu/v1/worker/mlu_quant.py
Normal file
281
vllm_mlu/v1/worker/mlu_quant.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
"""A MLU quant class."""
|
||||
import functools
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.distributed import (
|
||||
get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size,
|
||||
get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size,
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
import vllm.envs as envs
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (VocabParallelEmbedding,
|
||||
ParallelLMHead)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention
|
||||
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def default_act_range_value():
|
||||
return {
|
||||
"x": None,
|
||||
"split": None,
|
||||
"is_linear": False,
|
||||
"is_qkv": False,
|
||||
"q_proj_size": 0,
|
||||
"num_kv_head_replicas": 1,
|
||||
"is_merge": False,
|
||||
"input_id": [],
|
||||
"self_rank": 0,
|
||||
"rank": None,
|
||||
"tensor_rank": None,
|
||||
"tp_world_size": None,
|
||||
"moe_tp_rank": None,
|
||||
"moe_tp_world_size": None,
|
||||
"moe_ep_rank": None,
|
||||
"moe_ep_world_size": None,
|
||||
"weight": None,
|
||||
}
|
||||
|
||||
|
||||
def _str_to_torch_dtype(dtype: str) -> torch.dtype:
|
||||
dtype = dtype.split(".")[-1]
|
||||
# STR_DTYPE_TO_TORCH_DTYPE dict does not have float16 type
|
||||
return STR_DTYPE_TO_TORCH_DTYPE[dtype] if dtype != "float16" else torch.float16
|
||||
|
||||
|
||||
class ActRangeValue:
|
||||
"""
|
||||
ActRangeValue for the v1 MsgpackEncoder and MsgpackDecoder. This is a *workaround*.
|
||||
The decoded tensor can be wrong if we pass the act-range dict directly.
|
||||
|
||||
NOTE: here, we convert torch.Tensor to numpy ndarray because torch.Tensor
|
||||
may cause a core dump.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.layer_name: str = ""
|
||||
self.x: Optional[np.ndarray] = None
|
||||
self.split: Optional[str] = None
|
||||
self.is_linear: bool = False
|
||||
self.is_qkv: bool = False
|
||||
self.q_proj_size: int = 0
|
||||
self.num_kv_head_replicas: int = 1
|
||||
self.is_merge: bool = False
|
||||
self.input_id_dtype: Optional[str] = None
|
||||
self.input_id: Optional[List[np.ndarray]] = []
|
||||
self.self_rank: int = 0
|
||||
self.rank: Optional[int] = None
|
||||
self.tensor_rank: Optional[int] = None
|
||||
self.tp_world_size: Optional[int] = None
|
||||
self.moe_tp_rank: Optional[int] = None
|
||||
self.moe_tp_world_size: Optional[int] = None
|
||||
self.moe_ep_rank: Optional[int] = None
|
||||
self.moe_ep_world_size: Optional[int] = None
|
||||
self.weight: Optional[np.ndarray] = None
|
||||
self.weight_dtype: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def serial(cls, layer_name: str, act_range: Dict[str, Any]) -> 'ActRangeValue':
|
||||
instance = cls()
|
||||
instance.layer_name = layer_name
|
||||
instance.x = act_range.get("x")
|
||||
instance.split = act_range.get("split")
|
||||
instance.is_linear = act_range.get("is_linear", False)
|
||||
instance.is_qkv = act_range.get("is_qkv", False)
|
||||
instance.q_proj_size = act_range.get("q_proj_size", 0)
|
||||
instance.num_kv_head_replicas = act_range.get("num_kv_head_replicas", 1)
|
||||
instance.is_merge = act_range.get("is_merge", False)
|
||||
instance.input_id = act_range.get("input_id", [])
|
||||
instance.self_rank = act_range.get("self_rank", 0)
|
||||
instance.rank = act_range.get("rank")
|
||||
instance.tensor_rank = act_range.get("tensor_rank")
|
||||
instance.tp_world_size = act_range.get("tp_world_size")
|
||||
instance.moe_tp_rank = act_range.get("moe_tp_rank")
|
||||
instance.moe_tp_world_size = act_range.get("moe_tp_world_size")
|
||||
instance.moe_ep_rank = act_range.get("moe_ep_rank")
|
||||
instance.moe_ep_world_size = act_range.get("moe_ep_world_size")
|
||||
instance.weight = act_range.get("weight")
|
||||
|
||||
if instance.x is not None:
|
||||
instance.x = instance.x.numpy()
|
||||
# input_id and weight are used for debug
|
||||
if isinstance(instance.input_id, torch.Tensor):
|
||||
instance.input_id_dtype = str(instance.input_id.dtype)
|
||||
instance.input_id = instance.input_id.float().numpy()
|
||||
else:
|
||||
input_id_np = []
|
||||
for input_id in instance.input_id:
|
||||
instance.input_id_dtype = str(input_id.dtype)
|
||||
input_id_np.append(input_id.float().numpy())
|
||||
instance.input_id = input_id_np
|
||||
if instance.weight is not None:
|
||||
instance.weight_dtype = str(instance.weight.dtype)
|
||||
instance.weight = instance.weight.float().numpy()
|
||||
|
||||
return instance
|
||||
|
||||
def deserial(self) -> Dict[str, Any]:
|
||||
act_range = self.to_dict()
|
||||
if self.x is not None:
|
||||
act_range["x"] = torch.from_numpy(self.x)
|
||||
if self.input_id is not None:
|
||||
if isinstance(self.input_id, np.ndarray):
|
||||
act_range["input_id"] = torch.from_numpy(self.input_id).to(
|
||||
_str_to_torch_dtype(self.input_id_dtype))
|
||||
else:
|
||||
input_id_tensor = []
|
||||
for input_id in self.input_id:
|
||||
input_id_tensor.append(torch.from_numpy(input_id).to(
|
||||
_str_to_torch_dtype(self.input_id_dtype)))
|
||||
act_range["input_id"] = input_id_tensor
|
||||
if self.weight_dtype is not None:
|
||||
act_range["weight"] = torch.from_numpy(self.weight).to(
|
||||
_str_to_torch_dtype(self.weight_dtype))
|
||||
return act_range
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"x": self.x,
|
||||
"split": self.split,
|
||||
"is_linear": self.is_linear,
|
||||
"is_qkv": self.is_qkv,
|
||||
"q_proj_size": self.q_proj_size,
|
||||
"num_kv_head_replicas": self.num_kv_head_replicas,
|
||||
"is_merge": self.is_merge,
|
||||
"input_id": self.input_id,
|
||||
"self_rank": self.self_rank,
|
||||
"rank": self.rank,
|
||||
"tensor_rank": self.tensor_rank,
|
||||
"tp_world_size": self.tp_world_size,
|
||||
"moe_tp_rank": self.moe_tp_rank,
|
||||
"moe_tp_world_size": self.moe_tp_world_size,
|
||||
"moe_ep_rank": self.moe_ep_rank,
|
||||
"moe_ep_world_size": self.moe_ep_world_size,
|
||||
"weight": self.weight,
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"layer: {self.layer_name}, ActRangeValue({self.to_dict()})"
|
||||
|
||||
|
||||
class MLUWorkerQuant(object):
|
||||
'''
|
||||
MLU quantization helpers: collect per-layer activation-range statistics via forward hooks.
|
||||
'''
|
||||
def stat_tensor(self, name, tensor, act_range, key, dim):
|
||||
logger.debug(f"name:{name}, key:{key}, dim:{dim}, tensor.shape:{tensor.shape}")
|
||||
hidden_dim = tensor.shape[-1]
|
||||
tensor = tensor.view(-1, hidden_dim).abs()
|
||||
comming_max = torch.max(tensor, dim=dim)[0].float()
|
||||
|
||||
if act_range[name][key] is None:
|
||||
act_range[name][key] = comming_max
|
||||
else:
|
||||
act_range[name][key] = torch.max(act_range[name][key], comming_max)
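# i.e. for an activation of shape (tokens, hidden) and dim=0, `comming_max`
# holds the running per-hidden-channel maximum of |x| across every batch seen
# so far, presumably used later to derive the smoothing/quantization scales.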
|
||||
|
||||
def stat_input_hook(self, m, x, y, name, act_range, is_linear, is_save_input_id):
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
if isinstance(y, tuple):
|
||||
y = y[0]
|
||||
logger.debug(f"name:{name}, x.shape:{x.shape}, y.shape:{y.shape}, m.weight.shape:{m.weight.shape}")
|
||||
if is_linear:
|
||||
self.stat_tensor(name, x, act_range, "x", 0)
|
||||
if act_range[name]["is_qkv"] and is_save_input_id and ".0." in name:
|
||||
x_cpu = x.clone().to("cpu")
|
||||
act_range[name]["input_id"].append(x_cpu)
|
||||
|
||||
def setup_smooth_hook(self, is_save_input_id: bool = False, is_save_moe_info: bool = False):
|
||||
models = [self.model_runner.model]
|
||||
if hasattr(self.model_runner, "drafter") and self.model_runner.drafter is not None:
|
||||
models += [self.model_runner.drafter.model]
|
||||
self.act_range = defaultdict(default_act_range_value)
|
||||
self.hooks = []
|
||||
linear_class_list = (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
|
||||
other_class_list = (VocabParallelEmbedding, ParallelLMHead)
|
||||
class_list = linear_class_list + other_class_list
|
||||
row_class_list = (RowParallelLinear,)
|
||||
|
||||
for model in models:
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, FeedForward):
|
||||
m.use_bt_ffn = False
|
||||
if isinstance(m, SparseMoeMlp):
|
||||
m.is_use_fused_moe = False
|
||||
if isinstance(m, DeepseekV2MLAAttention):
|
||||
m.use_fused_mla_qkv = False
|
||||
|
||||
if isinstance(m, class_list):
|
||||
is_linear = isinstance(m, linear_class_list)
|
||||
split_type = "row" if isinstance(m, row_class_list) else "col"
|
||||
self.act_range[name]["split"] = split_type
|
||||
self.act_range[name]["is_linear"] = is_linear
|
||||
if isinstance(m, QKVParallelLinear):
|
||||
self.act_range[name]["is_qkv"] = True
|
||||
self.act_range[name]["q_proj_size"] = m.num_heads * m.head_size
|
||||
self.act_range[name]["num_kv_head_replicas"] = m.num_kv_head_replicas
|
||||
self.act_range[name]["is_merge"] = isinstance(m, MergedColumnParallelLinear)
|
||||
if is_save_moe_info:
|
||||
self.act_range[name]["rank"] = torch.distributed.get_rank()
|
||||
self.act_range[name]["tensor_rank"] = get_tensor_model_parallel_rank()
|
||||
self.act_range[name]["tp_world_size"] = get_tensor_model_parallel_world_size()
|
||||
self.act_range[name]["moe_tp_rank"] = get_moe_tensor_parallel_rank()
|
||||
self.act_range[name]["moe_tp_world_size"] = get_moe_tensor_parallel_world_size()
|
||||
self.act_range[name]["moe_ep_rank"] = get_moe_expert_parallel_rank()
|
||||
self.act_range[name]["moe_ep_world_size"] = get_moe_expert_parallel_world_size()
|
||||
if ".expert." in name:
|
||||
self.act_range[name]["weight"] = m.weight
|
||||
logger.info(f"rank:{self.rank}, add hook to {name}, is_linear:{is_linear}, split_type:{split_type}")
|
||||
self.hooks.append(m.register_forward_hook(functools.partial(self.stat_input_hook,
|
||||
name=name, act_range=self.act_range,
|
||||
is_linear=is_linear,
|
||||
is_save_input_id=is_save_input_id)))
|
||||
|
||||
def remove_hooks(self):
|
||||
for h in self.hooks:
|
||||
h.remove()
|
||||
|
||||
def get_act_range(self):
|
||||
act_range = defaultdict(default_act_range_value)
|
||||
for layer_name, layer_range in self.act_range.items():
|
||||
for tensor_key, tensor_value in layer_range.items():
|
||||
if isinstance(tensor_value, torch.Tensor):
|
||||
act_range[layer_name][tensor_key] = tensor_value.to("cpu")
|
||||
elif tensor_key == "input_id" and isinstance(tensor_value, list):
|
||||
input_id_len = len(tensor_value)
|
||||
for i in range(input_id_len):
|
||||
if isinstance(tensor_value[i], torch.Tensor):
|
||||
act_range[layer_name][tensor_key].append(tensor_value[i].to("cpu"))
|
||||
else:
|
||||
act_range[layer_name][tensor_key].append(tensor_value[i])
|
||||
else:
|
||||
act_range[layer_name][tensor_key] = tensor_value
|
||||
|
||||
serialization_result = []
|
||||
for layer_name, layer_range in act_range.items():
|
||||
serialization_result.append(ActRangeValue.serial(layer_name, layer_range))
|
||||
return serialization_result
|
||||
|
||||
@torch.no_grad()
|
||||
def get_named_parameters(self):
|
||||
name_parameters = {}
|
||||
for name, param in self.model_runner.model.named_parameters():
|
||||
name_parameters[name] = param.to("cpu")
|
||||
|
||||
return name_parameters
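# A minimal calibration sketch (illustrative driver; run_calibration_requests
# is hypothetical and not part of this patch), assuming the worker mixes in
# MLUWorkerQuant and already holds a loaded model_runner:
#
#   worker.setup_smooth_hook(is_save_input_id=False)
#   run_calibration_requests(worker)       # drive a few forward passes
#   act_ranges = worker.get_act_range()    # list of ActRangeValue, numpy-only
#   worker.remove_hooks()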
|
||||