# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import copy
from typing import Any, List, Optional

import torch
import torch.nn.functional as F

from vllm.config.vllm import CUDAGraphMode
from vllm.distributed import (
    get_logits_tp_group,
    get_logits_tp_world_size,
    get_tensor_model_parallel_world_size,
)
from vllm.distributed.communication_op import (
    tensor_model_parallel_all_gather_into_list)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                   FlashAttentionMetadata)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, logger

# NOTE: the wildcard import provides VLLM_LATENCY_DEBUG_WITH_DEVICE_EN,
# which gates the MLU event-based latency recording below.
from vllm_mlu._mlu_utils import *
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.model_executor.models.dp_utils import (
    DataParallelRuntimeParams,
    enable_data_parallel,
)
from vllm_mlu.v1.attention.backends.flash_attn import pad_attn_metadata
from vllm_mlu.v1.attention.backends.mla.flashmla import FlashMLAMetadataBuilder
from vllm_mlu.v1.attention.backends.utils import (COMMON_METADATA_STR,
                                                  MLUCommonAttentionMetadata,
                                                  MLUInferMode)
from vllm_mlu.v1.spec_decode.eagle import MluEagleProposer
class DPMluEagleProposer(MluEagleProposer):
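    """Data-parallel EAGLE/MTP draft-token proposer for MLU devices.

    Extends MluEagleProposer so that speculative drafting cooperates with
    data parallelism: per-rank logits batch sizes are exchanged when the
    logits TP group differs from the tensor-parallel group, idle ranks run
    dummy draft batches, and draft batches are padded so a captured full
    MLU graph can be replayed.
    """
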
    def get_logits_batch_sizes(self, batch_size: int) -> Optional[List[int]]:
        tp_world_size, logits_batch_sizes = get_logits_tp_world_size(), None
        if tp_world_size != get_tensor_model_parallel_world_size():
            tp_tensor = torch.tensor([batch_size]).to(self.runner.device)
            outputs = tensor_model_parallel_all_gather_into_list(
                tp_tensor, get_logits_tp_group())
            # Convert the gathered device tensors to a host list.
            outputs = torch.cat(outputs).tolist()
            logits_batch_sizes = outputs[:tp_world_size]
        return logits_batch_sizes
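
    # Dummy-batch drafting used with data parallelism: a rank that has no
    # real requests this step still executes the draft model and the logits
    # computation, so that collective ops stay aligned across DP ranks.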
    def propose_ds_execute_dummy_batch(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        dp_params: DataParallelRuntimeParams,
    ) -> None:
        # num_scheduled_tokens
        num_tokens = target_token_ids.shape[0]
        input_ids = self.input_ids[:num_tokens]
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        input_ids[:-1] = target_token_ids[1:]

        # Always skip the attention computation.
        attn_metadata: Optional[dict[str, Any]] = None

        # Get graph-capture related information for the deepseek model.
        with set_forward_context(attn_metadata, self.vllm_config,
                                 num_tokens=num_tokens):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=target_positions,
                hidden_states=target_hidden_states,
                intermediate_tensors=None,
                inputs_embeds=None,
                dp_params=dp_params,
            )
        if dp_params is not None:
            dp_params.logits_batch_split_list = self.get_logits_batch_sizes(
                num_tokens)
        _ = self.model.compute_logits(hidden_states, dp_params=dp_params)

        if self.num_speculative_tokens == 1:
            return
        '''
        =============================
        Modify by vllm_mlu
        @brief: support k > 1, need to run the draft model k - 1 more times
        =============================
        '''
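        # Each extra pass below mirrors one real drafting step: fresh
        # dp_params are built for a pure-decode batch (one token per
        # request), then the model and compute_logits run again, keeping
        # this rank in lockstep for the remaining k - 1 steps.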
        # support k > 1
        for _ in range(self.num_speculative_tokens - 1):
            new_dp_params = self.runner._get_data_parallel_metadata(
                num_tokens, num_tokens, True, [1] * num_tokens)
            with set_forward_context(attn_metadata, self.vllm_config,
                                     num_tokens=num_tokens):
                hidden_states = self.model(
                    input_ids=input_ids,
                    positions=target_positions,
                    hidden_states=target_hidden_states,
                    intermediate_tensors=None,
                    inputs_embeds=None,
                    dp_params=new_dp_params,
                )
            _ = self.model.compute_logits(hidden_states,
                                          dp_params=new_dp_params)
        '''
        =============================
        End of MLU Hijack
        =============================
        '''

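    # Main drafting entry point: run the draft model once over the scheduled
    # tokens, then loop num_speculative_tokens - 1 more single-token decode
    # steps, returning [batch_size, num_speculative_tokens] draft token ids.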
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: MLUCommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
        # [num_tokens]
        token_indices: torch.Tensor,
        whole_block_table: torch.Tensor,
        main_model_dp_params: Optional[DataParallelRuntimeParams] = None,
        # NOTE: a mutable default ([]) would be shared across calls, so use
        # None and create a fresh list per call instead.
        time_markers: Optional[List] = None,
    ) -> torch.Tensor:
        if time_markers is None:
            time_markers = []
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids
        hidden_states_indices = last_token_indices

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        # FIXME: need to consider multiple kv_cache_groups
        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata,
            draft_index=0,
        )
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: Use full graph with draft model and pad batch_size for dp
        '''
        dp_group_max_token_num = max(main_model_dp_params.token_split_list)
        if (dp_group_max_token_num <=
                self.vllm_config.compilation_config.max_cudagraph_capture_size):
            batch_descriptor_num_tokens = self.vllm_config.pad_for_cudagraph(
                dp_group_max_token_num)
            captured_already = True
        else:
            batch_descriptor_num_tokens = num_tokens
            captured_already = False

        # Determine if we can use the full graph.
        decode_only = all(not prefill
                          for prefill in main_model_dp_params.dp_is_prefill)
        # FIXME(wangchao2): disable mtp graph for ds3.2 with dp for now
        # (core dump).
        is_dsv32 = (self.vllm_config.model_config.hf_config.model_type ==
                    "deepseek_v32")
        use_full_graph = (self.use_cuda_graph and decode_only
                          and captured_already and not is_dsv32)
        if (self.use_cuda_graph and decode_only and not use_full_graph
                and not is_dsv32):
            logger.warning_once(
                "Selected MLU-V1 Full-MLUGraph mode with drafter, but running "
                f"in eager mode: decode_only={decode_only}, "
                f"captured_already={captured_already}, "
                f"num_tokens={num_tokens}.")

        cudagraph_runtime_mode = (CUDAGraphMode.FULL
                                  if use_full_graph else CUDAGraphMode.NONE)
        batch_descriptor = BatchDescriptor(
            num_tokens=batch_descriptor_num_tokens,
            uniform_decode=True,
        )

        # dp pad batch_size
        if use_full_graph:
            K = self.num_speculative_tokens
            num_input_tokens = batch_descriptor_num_tokens
            padded_batch_size = num_input_tokens // (K + 1)
        else:
            padded_batch_size = batch_size
            num_input_tokens = num_tokens
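
        # Padding math above: under uniform speculative decode each request
        # occupies K + 1 token slots (one verified token plus K draft
        # positions), so the padded batch size follows from the padded
        # token count.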
        # Update attn metadata num_actual_tokens.
        attn_metadata.num_actual_tokens = num_input_tokens

        common_attn_metadata_copy = None
        # Deep-copy common_attn_metadata when k > 1 for the draft model,
        # because dp batch_size padding will modify common_attn_metadata.
        if self.num_speculative_tokens > 1:
            common_attn_metadata_copy = copy.deepcopy(common_attn_metadata)
        # Pad the attn metadata.
        if (use_full_graph and enable_data_parallel()
                and num_input_tokens != num_tokens):
            assert self.runner is not None
            # Update attention metadata.
            pad_attn_metadata(
                attn_metadata,
                common_attn_metadata,
                whole_block_table,
                self.runner,
                num_tokens,
                num_input_tokens,
                batch_size,
                padded_batch_size,
            )

        # Update input ids, pad with 0 if necessary.
        token_pad_size = num_input_tokens - num_tokens
        assert token_pad_size >= 0
        # Update target hidden states, pad with zeros if necessary.
        if token_pad_size > 0:
            target_hidden_states = F.pad(
                target_hidden_states,
                (0, 0, 0, token_pad_size),
                value=0.0,
            )

        # Update positions, pad with zeros if necessary.
        if token_pad_size > 0:
            target_positions = F.pad(
                target_positions,
                (0, token_pad_size),
                value=0,
            )

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata
        # Copy inputs to the buffers for cudagraph.
        self.positions[:num_input_tokens] = target_positions
        self.hidden_states[:num_input_tokens] = target_hidden_states

        kwargs = ({} if main_model_dp_params is None else
                  {"dp_params": main_model_dp_params})
        if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
            start = torch.mlu.Event(enable_timing=True)
            start.record()
        with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                batch_descriptor=batch_descriptor if use_full_graph else None,
                cudagraph_runtime_mode=cudagraph_runtime_mode):
            if use_full_graph:
                ret_hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    intermediate_tensors=None,
                    inputs_embeds=None,
                    is_running_drafter=True,
                    **kwargs,
                )
            else:
                ret_hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    **kwargs,
                )
        if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
            end = torch.mlu.Event(enable_timing=True)
            end.record()
            time_markers.append([start, end])
        if self.method == "mtp":
            last_hidden_states = ret_hidden_states
        else:
            last_hidden_states, hidden_states = ret_hidden_states
        '''
        =============================
        End of MLU Hijack
        =============================
        '''
        if main_model_dp_params is not None:
            main_model_dp_params.logits_batch_split_list = (
                self.get_logits_batch_sizes(batch_size))

        sample_hidden_states = last_hidden_states[hidden_states_indices]
        logits = self.model.compute_logits(sample_hidden_states,
                                           dp_params=main_model_dp_params)

        draft_token_ids = logits.argmax(dim=-1)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)

        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]

        '''
        =============================
        Modify by vllm_mlu
        =============================
        '''
        hidden_states = last_hidden_states[hidden_states_indices]
        '''
        =============================
        End of MLU Hijack
        =============================
        '''
        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        input_batch_size = batch_size
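
        # From here on, every draft step is pure decode: each request
        # contributes exactly one query token, so query_start_loc collapses
        # to an arange over the batch.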
        if common_attn_metadata.infer_mode != MLUInferMode.DECODE_ONLY:
            seq_lens_cpu = torch.ones(input_batch_size, dtype=torch.int32)
            cu_num_tokens = torch.cumsum(seq_lens_cpu, dim=0)
            query_start_loc_cpu = torch.empty(input_batch_size + 1,
                                              dtype=torch.int32)
            query_start_loc_cpu[0] = 0
            query_start_loc_cpu[1:] = cu_num_tokens
            seq_start_loc_cpu = self.arange[:input_batch_size + 1]
            common_attn_metadata_k = MLUCommonAttentionMetadata.build(
                query_start_loc=query_start_loc_cpu.to(self.device,
                                                       non_blocking=True),
                query_start_loc_cpu=query_start_loc_cpu,
                seq_lens=seq_lens_cpu.to(self.device, non_blocking=True),
                seq_lens_cpu=seq_lens_cpu,
                num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
                num_reqs=common_attn_metadata.num_reqs,
                block_table_tensor=common_attn_metadata.block_table_tensor,
                slot_mapping=common_attn_metadata.slot_mapping,
                seq_start_loc=seq_start_loc_cpu.to(self.device,
                                                   non_blocking=True),
                is_start_loc_match=False,  # not prefill
                max_query_len=1,
                num_actual_tokens=input_batch_size,
                num_input_tokens=input_batch_size,
                num_speculative_tokens=self.num_speculative_tokens,
                has_prefill_reqs=(common_attn_metadata.infer_mode ==
                                  MLUInferMode.CHUNKED),
            )
        else:
            common_attn_metadata_k = common_attn_metadata_copy
            common_attn_metadata_k.num_actual_tokens = batch_size
            common_attn_metadata_k.num_input_tokens = batch_size
            common_attn_metadata_k.max_query_len = 1
            common_attn_metadata_k.query_start_loc = self.arange[:batch_size + 1]
            common_attn_metadata_k.query_start_loc_cpu = torch.from_numpy(
                self.token_arange_np[:batch_size + 1]).clone()
        # In the padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
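        # Draft loop: step i consumes the argmax token from step i - 1,
        # advances each request's position by one, recomputes the slot
        # mapping for the new position, and reruns the draft model on a
        # one-token-per-request batch.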
        for token_index in range(self.num_speculative_tokens - 1):
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: get dp_params for draft model
            '''
            # dp_params for the draft model; None when data parallelism is
            # not in use.
            dp_params = None
            if main_model_dp_params is not None:
                dp_params = self.runner._get_data_parallel_metadata(
                    input_batch_size,
                    input_batch_size,
                    common_attn_metadata.is_decode_only,
                    [1] * input_batch_size,
                )
            kwargs = {} if dp_params is None else {"dp_params": dp_params}
            '''
            =============================
            End of MLU Hijack
            =============================
            '''
            # Update the inputs.
            # Casting to int32 is crucial when the eagle model is compiled:
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft
                # model generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids and
                # slot mappings to avoid out-of-range access during model
                # execution. The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get an out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0,
                                                positions)

            # For data integrity with async scheduling, we shouldn't use
            # in-place operations, in case the tensors are modified in the
            # next step's `prepare_input` of the main model.
            # Increment the sequence lengths.
            common_attn_metadata_k.seq_lens += 1
            # This is an out-of-place operation to avoid modifying the
            # original tensor.
            common_attn_metadata_k.seq_lens_cpu = (
                common_attn_metadata_k.seq_lens_cpu + 1)
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata_k.seq_lens.masked_fill_(
                exceeds_max_model_len, 1)

            common_attn_metadata_k.num_computed_tokens_cpu = (
                common_attn_metadata_k.seq_lens_cpu - 1)
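
            # Slot mapping: the physical KV slot for each request is
            # block_table[req, pos // block_size] * block_size
            # + pos % block_size, computed from the clamped positions.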
            # Compute the slot mapping.
            if self.uses_mrope:
                # All dimensions of mrope positions share the same value here.
                block_numbers = clamped_positions[0] // self.block_size
            else:
                block_numbers = clamped_positions // self.block_size
            block_ids = common_attn_metadata_k.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            if self.uses_mrope:
                common_attn_metadata_k.slot_mapping = (
                    block_ids * self.block_size +
                    clamped_positions[0] % self.block_size)
            else:
                common_attn_metadata_k.slot_mapping = (
                    block_ids * self.block_size +
                    clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata_k.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID)

            # Rebuild the attention metadata.
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata_k,
                draft_index=token_index + 1,
            )
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata
            per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata_k

            # Copy inputs to the buffers for cudagraph.
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(
                    input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: record latency
            '''
            if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
                start = torch.mlu.Event(enable_timing=True)
                start.record()
            '''
            =============================
            End of MLU Hijack
            =============================
            '''
            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    input_ids=self.input_ids[:input_batch_size],
                    positions=self.positions[:input_batch_size],
                    hidden_states=self.hidden_states[:input_batch_size],
                    **kwargs,
                )
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: record latency
            '''
            if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
                end = torch.mlu.Event(enable_timing=True)
                end.record()
                time_markers.append([start, end])
            '''
            =============================
            End of MLU Hijack
            =============================
            '''
            hidden_states = last_hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               dp_params=dp_params)

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids