forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
vllm/zero_overhead/v1/eagle.py | 317 lines | new file
@@ -0,0 +1,317 @@
import torch

from vllm.forward_context import set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer


class V1ZeroEagleProposer(EagleProposer):

    def __init__(self, vllm_config, device, runner=None):
        super().__init__(vllm_config, device, runner)
        self.spec_scheduler_max_num_tokens = 0
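        # Note: this counter is only read below as max_query_len for the
        # "deepseek_mtp" path; 0 is just a placeholder, and the value is
        # presumably filled in elsewhere (e.g. by the runner) before
        # propose() is called.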

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        # [batch_size]
        sampling_metadata: SamplingMetadata,
        decoding: bool = False,
    ) -> torch.Tensor:
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1
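        # For illustration: if cu_num_tokens = [0, 1, 3, 6] (three requests
        # contributing 1, 2 and 3 tokens), then last_token_indices = [0, 2, 5],
        # i.e. the flattened index of each request's final token.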

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        # FA requires seq_len to have dtype int32.
        seq_lens = (target_positions[last_token_indices] + 1).int()

        if self.method in ["eagle", "eagle3"]:
            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
            max_seq_len = seq_lens.max().item()
            max_num_tokens = (cu_num_tokens[1:] -
                              cu_num_tokens[:-1]).max().item()
            attn_metadata = FlashAttentionMetadata(
                num_actual_tokens=num_tokens,
                max_query_len=max_num_tokens,
                query_start_loc=cu_num_tokens,
                max_seq_len=max_seq_len,
                seq_lens=seq_lens,
                block_table=block_table,
                slot_mapping=target_slot_mapping,
                # TODO(woosuk): Support cascade attention.
                use_cascade=False,
                common_prefix_len=0,
                cu_prefix_query_lens=None,
                prefix_kv_lens=None,
                suffix_kv_lens=None,
            )
        elif self.method == "deepseek_mtp":
            max_query_len = self.spec_scheduler_max_num_tokens
            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=cu_num_tokens,
                seq_lens=seq_lens,
                num_reqs=batch_size,
                num_actual_tokens=num_tokens,
                max_query_len=max_query_len,
                slot_mapping=target_slot_mapping,
                spec_layer_decoding=decoding
            )

            assert self.runner is not None

            # FIXME: need to consider multiple kv_cache_groups
            attn_metadata = self.runner.attn_metadata_builders[0].build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata
            )
        else:
            raise ValueError(f"Unsupported method: {self.method}")

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        if self.use_cuda_graph and \
                num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            num_input_tokens = num_tokens
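        # For illustration (assumed behavior of pad_for_cudagraph): the token
        # count is rounded up to the nearest captured CUDA graph size, e.g.
        # 6 -> 8 when 8 is among cudagraph_batch_sizes; the extra slots are
        # padding.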
        # copy inputs to buffer for cudagraph
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

        if (decoding and self.use_full_cuda_graph
                and num_tokens <= self.cudagraph_batch_sizes[-1]):
            assert self.attn_metadata_cudagraph
            if self.method in ["eagle", "eagle3"]:
                self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
                    attn_metadata.seq_lens)
                self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
                    attn_metadata.slot_mapping)
                self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
                    attn_metadata.query_start_loc)
                self.attn_metadata_cudagraph.block_table[:batch_size] = (
                    attn_metadata.block_table)
            elif self.method == "deepseek_mtp":
                self.attn_metadata_cudagraph.num_actual_tokens = (
                    attn_metadata.num_actual_tokens)
                self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
                    attn_metadata.query_start_loc)
                self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
                    attn_metadata.slot_mapping)
                self.attn_metadata_cudagraph.num_decodes = (
                    attn_metadata.num_decodes)
                self.attn_metadata_cudagraph.num_decode_tokens = (
                    attn_metadata.num_decode_tokens)
                self.attn_metadata_cudagraph.num_prefills = (
                    attn_metadata.num_prefills)

                if attn_metadata.decode is not None:
                    self.attn_metadata_cudagraph.decode.block_table[
                        :attn_metadata.num_decode_tokens] = (
                            attn_metadata.decode.block_table)
                    self.attn_metadata_cudagraph.decode.seq_lens[
                        :attn_metadata.num_decode_tokens] = (
                            attn_metadata.decode.seq_lens)

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens,
                                 skip_cuda_graphs=not decoding):
            ret_hidden_states = self.model(
                self.input_ids[:num_input_tokens],
                self.positions[:num_input_tokens],
                self.hidden_states[:num_input_tokens],
            )
            if self.method == "deepseek_mtp":
                last_hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        draft_token_ids = logits.argmax(dim=-1)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)

        # TODO: Currently, MTP module released by deepseek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        positions = target_positions[last_token_indices]

        if self.method == "deepseek_mtp":
            hidden_states = last_hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if self.use_cuda_graph and \
                batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]
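        # For illustration: from here on each request feeds exactly one draft
        # token back into the model per step, so query_start_loc is simply
        # [0, 1, 2, ..., batch_size] and max_query_len is 1.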

        if isinstance(attn_metadata, MLACommonMetadata):
            attn_metadata.num_decodes = batch_size
            attn_metadata.num_decode_tokens = batch_size
            attn_metadata.num_prefills = 0
            block_table = self.runner.attn_metadata_builders[
                0].block_table.get_device_tensor()[:batch_size, ...]
            attn_metadata.decode = self.runner.attn_metadata_builders[
                0]._build_decode(
                    block_table_tensor=block_table,
                    seq_lens=seq_lens,
                )

        for i in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            if isinstance(attn_metadata, MLACommonMetadata):
                attn_metadata.decode.seq_lens += 1
            else:
                attn_metadata.seq_lens += 1

            # Increment the sequence lengths.
            attn_metadata.max_seq_len += 1
            # Consider max model length.
            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                            self.max_model_len)

            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = block_table.gather(dim=1,
                                           index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
                                          clamped_positions % self.block_size)
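            # For illustration: with block_size = 16 and a clamped position of
            # 37, the block number is 37 // 16 = 2, the in-block offset is
            # 37 % 16 = 5, so the slot is block_table[req, 2] * 16 + 5.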
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            self.hidden_states[:batch_size] = hidden_states

            if (self.use_full_cuda_graph
                    and batch_size <= self.cudagraph_batch_sizes[-1]):
                assert self.attn_metadata_cudagraph
                if self.method in ["eagle", "eagle3"]:
                    self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
                        attn_metadata.seq_lens)
                    self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
                        attn_metadata.slot_mapping)
                    if i == 0:
                        self.attn_metadata_cudagraph.query_start_loc[
                            :batch_size + 1] = (attn_metadata.query_start_loc)
                        self.attn_metadata_cudagraph.block_table[:batch_size] = (
                            attn_metadata.block_table)
                elif self.method == "deepseek_mtp":
                    self.attn_metadata_cudagraph.num_actual_tokens = (
                        attn_metadata.num_actual_tokens)
                    self.attn_metadata_cudagraph.slot_mapping[
                        :attn_metadata.num_decode_tokens] = (
                            attn_metadata.slot_mapping)
                    self.attn_metadata_cudagraph.num_decodes = (
                        attn_metadata.num_decodes)
                    self.attn_metadata_cudagraph.num_decode_tokens = (
                        attn_metadata.num_decode_tokens)
                    self.attn_metadata_cudagraph.num_prefills = (
                        attn_metadata.num_prefills)
                    self.attn_metadata_cudagraph.decode.seq_lens[
                        :attn_metadata.num_decode_tokens] = (
                            attn_metadata.decode.seq_lens)

                    if i == 0:
                        self.attn_metadata_cudagraph.query_start_loc[
                            :batch_size + 1] = (attn_metadata.query_start_loc)
                        self.attn_metadata_cudagraph.decode.block_table[
                            :attn_metadata.num_decode_tokens] = (
                                attn_metadata.decode.block_table)

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    self.input_ids[:input_batch_size],
                    self.positions[:input_batch_size],
                    self.hidden_states[:input_batch_size],
                )
                if self.method == "deepseek_mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = last_hidden_states[:batch_size]
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
                    hidden_states = hidden_states[:batch_size]

            logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                               None)

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)

        return draft_token_ids