[Feat] Merge the multiple Eagle graphs into one graph (#5940)
### What this PR does / why we need it?
This PR merges all draft-model steps into one graph in fullgraph mode,
avoiding the synchronization between per-step graphs and reducing bubble time.
#### Key ideas:
- The "model forward" of the step 0 (first step) and remaining steps are
captured together as a "Callable", rather than capturing each model
individually.
- "update_attn_params" is moved outside the entire graph, meaning that
all "attn_metadata" required by all steps are constructed before
"replay", and the "attn_params" of all steps are updated at once.
- Remove synchronization between the main model graph and draft model
graph.
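
A minimal structural sketch of the merged-graph idea in plain PyTorch
(`ToyDraft`, `run_merged_draft`, and the metadata list are illustrative
stand-ins, not the actual vllm-ascend APIs): all speculative steps run inside
one callable, so a graph backend can capture and replay them as a single
graph instead of one graph per step.

```python
import torch
import torch.nn as nn


class ToyDraft(nn.Module):
    # Stand-in for the draft model; one forward call == one speculative step.
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)


def run_merged_draft(model: nn.Module, hidden_states: torch.Tensor,
                     multi_steps_attn_metadata: list) -> torch.Tensor:
    # Every speculative step runs inside this one callable, so capturing
    # it yields a single graph with no per-step launch/synchronize bubble.
    for step_metadata in multi_steps_attn_metadata:
        # Metadata for every step was built *before* capture/replay;
        # inside the graph it is only consumed, never constructed.
        _ = step_metadata
        hidden_states = model(hidden_states)
    return hidden_states


model = ToyDraft()
out = run_merged_draft(model, torch.randn(4, 16), [{}, {}, {}])
print(out.shape)  # torch.Size([4, 16])
```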
#### Key params/functions:
- params: `draft_attn_metadatas`, `attn_metadata_multi_steps`,
`slot_mapping_group` (per-step buffers, illustrated in the sketch below)
- functions: `_run_merged_draft`, `attn_update_stack_num_spec_norm`,
`update_attn_params`, `_propose`, `dummy_run`
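
A simplified sketch of the `slot_mapping_group` idea (buffer sizes and the
update helper are illustrative, not the real signatures): one fixed buffer is
pre-allocated per speculative step so the captured graph always reads from
stable addresses, and the contents of all steps are rewritten in one pass
before replay.

```python
import torch

num_steps, max_tokens = 3, 8
# One buffer per speculative step: the captured graph reads these fixed
# addresses, so only their contents need rewriting before each replay.
slot_mapping_group = [
    torch.zeros(max_tokens, dtype=torch.int32) for _ in range(num_steps)
]


def update_all_steps(buffers: list, new_mappings: list) -> None:
    # Update the slot mappings of *all* steps at once, before replay,
    # instead of synchronizing between per-step graphs.
    for buf, src in zip(buffers, new_mappings):
        n = min(buf.shape[0], src.shape[0])
        buf[:n].copy_(src[:n])
        buf[n:].fill_(-1)  # mark padding slots as invalid


update_all_steps(slot_mapping_group,
                 [torch.arange(5, dtype=torch.int32)] * num_steps)
print(slot_mapping_group[0])  # tensor([0, 1, 2, 3, 4, -1, -1, -1], ...)
```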
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 11b6af5280
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
+import copy
 from contextlib import contextmanager, nullcontext
-from typing import Any, ContextManager, Optional
+from typing import Any, Callable, ContextManager, Optional, Union

 import numpy as np
 import torch
@@ -84,6 +85,8 @@ def split_inputs_tp_to_sp(hidden_states, out):

 class EagleProposer(VllmEagleProposer):

+    _runnable: Union[ACLGraphWrapper, Callable]
+
     def __init__(self,
                  vllm_config: VllmConfig,
                  device: torch.device,
@@ -136,14 +139,29 @@ class EagleProposer(VllmEagleProposer):
         self.tp_group_context = nullcontext()

         self.use_cuda_graph = (self.runner._use_aclgraph()
-                               and not self.speculative_config.enforce_eager
-                               and not self.use_async_scheduling)
+                               and not self.speculative_config.enforce_eager)
+        if self.method == "mtp":
+            self.use_cuda_graph = self.use_cuda_graph and not self.use_async_scheduling
+
+        # TODO: Remove it when the bug of fx-graph is solved
+        self.maybe_eager_context: ContextManager[Any] = nullcontext()
+        if not self.use_cuda_graph and enable_sp(vllm_config):
+            self.maybe_eager_context = _maybe_eager_context(vllm_config)

         self.last_token_indices = torch.zeros(
             self.vllm_config.scheduler_config.max_num_batched_tokens,
             dtype=torch.int32,
             device=device)
+        slot_mapping_lens = self.runner.max_num_tokens + \
+            2 * self.pcp_size * self.runner.max_num_reqs
+        self.slot_mapping_group = [
+            torch.zeros(
+                slot_mapping_lens, dtype=torch.int32, device=device,
+                pin_memory=self.runner.pin_memory)
+            for _ in range(self.num_speculative_tokens)]
+
+        self._runnable = self._run_merged_draft

     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config,
@@ -166,7 +184,17 @@ class EagleProposer(VllmEagleProposer):
         draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
         draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
         assert len(draft_attn_layer_names) == 1
-        self.attn_layer_names = list(draft_attn_layer_names)
+        self.attn_layer_names = list(sorted(draft_attn_layer_names))
+        self.piece_all_attn_layer_name = []
+        for _ in range(self.num_speculative_tokens):
+            self.piece_all_attn_layer_name.append([
+                name for name in self.attn_layer_names])
+        self.attn_layer_names = list(sorted(draft_attn_layer_names))
+
+        self.piece_all_attn_layer_name = []
+        for _ in range(self.num_speculative_tokens):
+            self.piece_all_attn_layer_name.append([
+                name for name in self.attn_layer_names])

         if supports_multimodal(model):
             # handle multimodality
@@ -236,9 +264,14 @@ class EagleProposer(VllmEagleProposer):
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
         ) and self.use_cuda_graph:
             self.update_stream = torch.npu.Stream()
-            self.model = ACLGraphWrapper(self.model,
-                                         self.vllm_config,
-                                         runtime_mode=CUDAGraphMode.FULL)
+            if self.method == "mtp":
+                self.model = ACLGraphWrapper(self.model,
+                                             self.vllm_config,
+                                             runtime_mode=CUDAGraphMode.FULL)
+            else:
+                self._runnable = ACLGraphWrapper(self._run_merged_draft,
+                                                 self.vllm_config,
+                                                 runtime_mode=CUDAGraphMode.FULL)

     def get_model(self) -> nn.Module:
         # get raw model out of the aclgraph wrapper.
@@ -246,6 +279,11 @@ class EagleProposer(VllmEagleProposer):
             return self.model.unwrap()
         return self.model

+    def shallow_copy_metadata(self, attn_metadata):
+        # Currently, new objects will be assigned to the lists in attn_metadata
+        # when update. So we can use the shallow copy.
+        return copy.copy(attn_metadata)
+
     @torch.inference_mode()
     def dummy_run(self,
                   num_tokens: int,
@@ -260,7 +298,7 @@ class EagleProposer(VllmEagleProposer):
         # update global cos, sin
         update_cos_sin(self._get_positions(num_tokens))

-        attn_metadata = None
+        multi_steps_attn_metadata = []
         if not self.use_cuda_graph:
             aclgraph_runtime_mode = CUDAGraphMode.NONE
         if aclgraph_runtime_mode == CUDAGraphMode.FULL and len(
@@ -286,6 +324,7 @@ class EagleProposer(VllmEagleProposer):
                 actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
                 block_table_tensor=self.runner.input_batch.block_table[0].
                 get_device_tensor()[:num_reqs],
+                # This is used to hold a position.
                 slot_mapping=self.runner.input_batch.block_table[0].
                 slot_mapping.gpu,
                 positions=self.runner.positions.gpu,
@@ -295,46 +334,49 @@ class EagleProposer(VllmEagleProposer):
         )

         builder = self.runner.attn_groups[0][0].get_metadata_builder()
-        attn_metadata_eagle = builder.build_for_graph_capture(
-            common_attn_metadata, AscendAttentionState.ChunkedPrefill)
-        attn_metadata = {}
-        for layer_name in self.attn_layer_names:
-            attn_metadata[layer_name] = attn_metadata_eagle
+        # update the tensor's address for each step.
+        for draft_step in range(self.num_speculative_tokens):
+            common_attn_metadata = self.shallow_copy_metadata(
+                common_attn_metadata)
+            # Set the real slot_mapping.
+            common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step]
+            attn_metadata_eagle = builder.build_for_graph_capture(
+                common_attn_metadata, AscendAttentionState.ChunkedPrefill)
+            per_layer_attn_metadata = dict()
+            for layer_name in self.attn_layer_names:
+                per_layer_attn_metadata[layer_name] = attn_metadata_eagle
+            multi_steps_attn_metadata.append(per_layer_attn_metadata)

         model_input_ids = self.input_ids[:num_tokens]
         model_positions = self._get_positions(num_tokens)
         model_previous_hidden_states = self.hidden_states[:num_tokens]
-        for i in range(self.num_speculative_tokens):
-            if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
-                aclgraph_runtime_mode = CUDAGraphMode.NONE
-            with set_ascend_forward_context(
-                    attn_metadata,
-                    self.vllm_config,
-                    num_tokens=num_tokens,
-                    num_actual_tokens=0,
-                    in_profile_run=is_profile,
-                    batch_descriptor=batch_descriptor,
-                    aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    is_draft_model=True):
-
-                model_previous_hidden_states, model_positions = self.maybe_pad_and_reduce(
-                    model_previous_hidden_states, model_positions)
+        batch_size = num_tokens // (self.num_speculative_tokens + 1)
+        with set_ascend_forward_context(
+                multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_actual_tokens=0,
+                in_profile_run=is_profile,
+                batch_descriptor=batch_descriptor,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                is_draft_model=True):

-                self.model(
-                    input_ids=model_input_ids,
-                    positions=model_positions,
-                    hidden_states=model_previous_hidden_states,
-                )
-                forward_context = get_forward_context()
-                if (forward_context.cudagraph_runtime_mode
-                        == CUDAGraphMode.FULL
-                        and not forward_context.capturing):
-                    self._update_full_graph_params(forward_context, num_tokens)
-
-                model_previous_hidden_states, model_positions, _ = self.maybe_all_gather_and_unpad(
-                    model_previous_hidden_states, model_positions)
-
             dummy_compute_logits(self.hidden_states)
+            self._runnable(
+                num_input_tokens=num_tokens,
+                batch_size=batch_size,
+                last_token_indices=self.last_token_indices[:batch_size],
+                # The target_position's address is same as the model_positions's
+                target_positions=model_positions,
+                inputs_embeds=None,
+                multi_steps_attn_metadata=multi_steps_attn_metadata,
+            )
+            forward_context = get_forward_context()
+            if (forward_context.cudagraph_runtime_mode
+                    == CUDAGraphMode.FULL
+                    and not forward_context.capturing):
+                self._update_full_graph_params(forward_context, num_tokens,
+                                               multi_steps_attn_metadata)

     def _propose(
         self,
@@ -408,17 +450,59 @@ class EagleProposer(VllmEagleProposer):
         inputs_embeds = None
         input_ids = self.input_ids[:num_input_tokens]

+        # Update slot_mapping for different speculative.
+        # NOTE: Currently, we only remake the slot_mapping, because it's the
+        # only tensor which will be used in current FIA.
+        # Strictly speaking, `query_start_loc`, `seq_lens` should also have
+        # their memory allocated separately for each step just like `slot_mapping`.
+        slot_mapping_lens = num_input_tokens if num_input_tokens < \
+            common_attn_metadata.slot_mapping.shape[0] else \
+            common_attn_metadata.slot_mapping.shape[0]
+        self.slot_mapping_group[0][:slot_mapping_lens].copy_(
+            common_attn_metadata.slot_mapping[:slot_mapping_lens])
+        self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1)
+        common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens]
+
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata = builder.build(0, common_attn_metadata,
                                       self.runner.get_model())

         # update global cos, sin
         update_cos_sin(self._get_positions(num_input_tokens))
-        per_layer_attn_metadata = {}

+        if self.uses_mrope:
+            used_update_positions = target_positions[:, last_token_indices]
+        else:
+            used_update_positions = target_positions[last_token_indices]
+        per_layer_attn_metadata = dict()
+        # The first step of speculative.
+        for layer_name in self.attn_layer_names:
+            per_layer_attn_metadata[layer_name] = attn_metadata
+        multi_steps_attn_metadata = [per_layer_attn_metadata]
+
+        # Copy the old attn_metadata and update
+        for draft_step in range(1, self.num_speculative_tokens):
+            common_attn_metadata, attn_metadata = \
+                self.attn_update_stack_num_spec_norm(
+                    draft_step,
+                    attn_metadata,
+                    common_attn_metadata,
+                    batch_size,
+                    num_input_tokens,
+                    used_update_positions,
+                    aclgraph_runtime_mode)
+            per_layer_attn_metadata = dict()
+            for layer_name in self.attn_layer_names:
+                per_layer_attn_metadata[layer_name] = attn_metadata
+            multi_steps_attn_metadata.append(per_layer_attn_metadata)
+
+        last_token_indices_len = last_token_indices.shape[0]
+        self.last_token_indices[:last_token_indices_len].copy_(
+            last_token_indices)
+
         with set_ascend_forward_context(
-                per_layer_attn_metadata,
+                multi_steps_attn_metadata[0],
                 self.vllm_config,
                 num_tokens=num_input_tokens,
                 num_actual_tokens=num_tokens,
@@ -426,34 +510,52 @@ class EagleProposer(VllmEagleProposer):
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 is_draft_model=True):

-            # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
-            # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
-            model_input_ids = self.input_ids[:num_input_tokens]
-            model_positions = self._get_positions(num_input_tokens)
-            model_hidden_states = self.hidden_states[:num_input_tokens]
-
-            model_hidden_states, model_positions = self.maybe_pad_and_reduce(
-                model_hidden_states, model_positions)
-
-            ret_hidden_states = self.model(
-                input_ids=model_input_ids,
-                positions=model_positions,
-                hidden_states=model_hidden_states,
-                inputs_embeds=inputs_embeds
-            )
-            if self.method == "mtp":
-                last_hidden_states = ret_hidden_states
-                hidden_states = last_hidden_states
-            else:
-                last_hidden_states, hidden_states = ret_hidden_states
+            draft_token_ids = self._runnable(
+                num_input_tokens=num_input_tokens,
+                batch_size=batch_size,
+                last_token_indices=self.last_token_indices[:last_token_indices_len],
+                target_positions=target_positions,
+                inputs_embeds=inputs_embeds,
+                multi_steps_attn_metadata=multi_steps_attn_metadata)

             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                 self._update_full_graph_params(forward_context,
-                                               num_input_tokens)
+                                               num_input_tokens,
+                                               multi_steps_attn_metadata)
+            return draft_token_ids

-            last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
-                last_hidden_states, model_positions, hidden_states)
+    def _run_merged_draft(self,
+                          num_input_tokens,
+                          batch_size,
+                          last_token_indices,
+                          target_positions,
+                          inputs_embeds,
+                          multi_steps_attn_metadata,
+                          ) -> torch.Tensor:
+        # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
+        # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
+        model_input_ids = self.input_ids[:num_input_tokens]
+        model_positions = self._get_positions(num_input_tokens)
+        model_hidden_states = self.hidden_states[:num_input_tokens]
+
+        model_hidden_states, model_positions = self.maybe_pad_and_reduce(
+            model_hidden_states, model_positions)
+
+        ret_hidden_states = self.model(
+            input_ids=model_input_ids,
+            positions=model_positions,
+            hidden_states=model_hidden_states,
+            inputs_embeds=inputs_embeds,
+        )
+        if self.method == "mtp":
+            last_hidden_states = ret_hidden_states
+            hidden_states = last_hidden_states
+        else:
+            last_hidden_states, hidden_states = ret_hidden_states
+
+        last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
+            last_hidden_states, model_positions, hidden_states)

         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states)
@@ -477,53 +579,17 @@ class EagleProposer(VllmEagleProposer):
        hidden_states = hidden_states[last_token_indices]
        last_token_indices = self.arange[:batch_size]

        if self.use_cuda_graph and \
                batch_size <= self.runner.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
-            input_batch_size = batch_size
+            input_batch_size = num_input_tokens

        if self.use_cuda_graph:
            aclgraph_runtime_mode, batch_descriptor = \
                self.runner.cudagraph_dispatcher.dispatch(num_tokens=input_batch_size, uniform_decode=True, has_lora=has_lora)
        else:
            aclgraph_runtime_mode = CUDAGraphMode.NONE
            batch_descriptor = None
-        forward_context = get_forward_context()
-        forward_context.num_tokens = input_batch_size
-        forward_context.num_accept_tokens = batch_size
-
-        if (
-            aclgraph_runtime_mode == CUDAGraphMode.FULL
-            and (pad_size := input_batch_size - batch_size) > 0
-        ):
-            common_attn_metadata.num_reqs = input_batch_size
-            common_attn_metadata.block_table_tensor = self._pad_tensor(
-                common_attn_metadata.block_table_tensor, pad_size)
-            common_attn_metadata.seq_lens = self._pad_tensor(
-                common_attn_metadata.seq_lens, pad_size)
-            common_attn_metadata.seq_lens_cpu = self._pad_tensor(
-                common_attn_metadata.seq_lens_cpu, pad_size)
-            common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor(
-                common_attn_metadata.num_computed_tokens_cpu, pad_size)
-            common_attn_metadata.query_start_loc = self.arange[
-                :input_batch_size + 1]
-            common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
-                self.token_arange_np[:input_batch_size + 1]).clone()
-        else:
-            common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-            common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
-                self.token_arange_np[:batch_size + 1]).clone()
-
-        common_attn_metadata.num_actual_tokens = batch_size
-        common_attn_metadata.max_query_len = 1
-        common_attn_metadata.decode_token_per_req = 1
-        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
-        common_attn_metadata.graph_pad_size = -1
-        common_attn_metadata.num_input_tokens = input_batch_size

-        for now_speculative in range(self.num_speculative_tokens - 1):
+        for draft_step in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
-            input_ids = draft_token_ids_tensor[now_speculative]
+            input_ids = draft_token_ids_tensor[draft_step]
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
@@ -545,67 +611,6 @@ class EagleProposer(VllmEagleProposer):
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

-            # For data integrity when async scheduling, we shouldn't use in place
-            # operations in case they are modified in next step's `prepare_input`
-            # of main model.
-            # Increment the sequence lengths.
-            common_attn_metadata.seq_lens[:batch_size] += 1
-            # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
-            common_attn_metadata.seq_lens[:batch_size].masked_fill_(
-                exceeds_max_model_len, 1)
-
-            common_attn_metadata.seq_lens_cpu[:batch_size] = (
-                common_attn_metadata.seq_lens_cpu[:batch_size] + 1)
-            exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \
-                self.max_model_len
-            common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(
-                exceeds_mask, 1)
-            common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1
-            if self.uses_mrope:
-                common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0])
-            else:
-                common_attn_metadata.positions[:batch_size].copy_(clamped_positions)
-            if self.attn_metadata_builder is None:
-                attn_metadata_builder = self._get_attention_metadata_builder()
-            else:
-                attn_metadata_builder = self.attn_metadata_builder
-            block_size = attn_metadata_builder.kv_cache_spec.block_size
-
-            # Compute the slot mapping.
-            if self.uses_mrope:
-                block_numbers = clamped_positions[0] // block_size
-            else:
-                block_numbers = (clamped_positions // block_size)
-            block_ids = attn_metadata.block_tables.gather(
-                dim=1, index=block_numbers.view(-1, 1))
-            block_ids = block_ids.view(-1)
-            if self.uses_mrope:
-                slot_mapping = (block_ids * block_size +
-                                clamped_positions[0] % block_size)
-            else:
-                slot_mapping = (block_ids * block_size +
-                                clamped_positions % block_size)
-
-            # Mask out the slot mappings that exceed the max model length.
-            # Otherwise, the KV cache will be inadvertently updated with the
-            # padding tokens.
-            slot_mapping.masked_fill_(exceeds_max_model_len,
-                                      PADDING_SLOT_ID)
-
-            common_attn_metadata.slot_mapping[:slot_mapping.shape[0]].copy_(
-                slot_mapping.to(torch.int32))
-            common_attn_metadata.slot_mapping[slot_mapping.shape[0]:].fill_(
-                PADDING_SLOT_ID)
-
-            # Rebuild attention metadata
-            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
-                common_attn_metadata=common_attn_metadata,
-                draft_index=now_speculative + 1,
-            )
-            for layer_name in self.attn_layer_names:
-                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
@@ -624,55 +629,175 @@ class EagleProposer(VllmEagleProposer):
            update_cos_sin(self._get_positions(input_batch_size))

            # Run the model.
-            with set_ascend_forward_context(
-                    per_layer_attn_metadata,
-                    self.vllm_config,
-                    num_tokens=input_batch_size,
-                    num_actual_tokens=batch_size,
-                    batch_descriptor=batch_descriptor,
-                    aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    is_draft_model=True):
-
-                # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
-                # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
-                model_input_ids = self.input_ids[:input_batch_size]
-                model_positions = self._get_positions(input_batch_size)
-                model_hidden_states = self.hidden_states[:input_batch_size]
+            # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
+            # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
+            model_input_ids = self.input_ids[:input_batch_size]
+            model_positions = self._get_positions(input_batch_size)
+            model_hidden_states = self.hidden_states[:input_batch_size]

-                model_hidden_states, model_positions = self.maybe_pad_and_reduce(
-                    model_hidden_states, model_positions)
+            model_hidden_states, model_positions = self.maybe_pad_and_reduce(
+                model_hidden_states, model_positions)

-                ret_hidden_states = self.model(
-                    input_ids=model_input_ids,
-                    positions=model_positions,
-                    hidden_states=model_hidden_states,
-                    inputs_embeds=inputs_embeds
-                )
-                if self.method == "mtp":
-                    last_hidden_states = ret_hidden_states
-                    hidden_states = last_hidden_states
-                else:
-                    last_hidden_states, hidden_states = ret_hidden_states
+            forward_context.attn_metadata = multi_steps_attn_metadata[draft_step + 1] \
+                if multi_steps_attn_metadata else None
+            ret_hidden_states = self.model(
+                input_ids=model_input_ids,
+                positions=model_positions,
+                hidden_states=model_hidden_states,
+                inputs_embeds=inputs_embeds,
+            )
+            if self.method == "mtp":
+                last_hidden_states = ret_hidden_states
+                hidden_states = last_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states

-                forward_context = get_forward_context()
-                if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-                    self._update_full_graph_params(forward_context,
-                                                   input_batch_size)
-
-                last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
-                    last_hidden_states, model_positions, hidden_states)
+            last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
+                last_hidden_states, model_positions, hidden_states)

            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])

            # TODO(wenlong): get more than one token for tree attention
            draft_token_ids = logits.argmax(dim=-1)
-            draft_token_ids_tensor[now_speculative + 1] = draft_token_ids
+            draft_token_ids_tensor[draft_step + 1] = draft_token_ids

        # [batch_size, num_speculative_tokens]
        draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
        return draft_token_ids

+    def attn_update_stack_num_spec_norm(self,
+                                        # `draft_step` must start from `1`, no `0`
+                                        draft_step,
+                                        old_attn_metadata,
+                                        old_common_metadata,
+                                        batch_size,
+                                        input_batch_size,
+                                        used_update_positions,
+                                        aclgraph_runtime_mode):
+
+        assert(draft_step > 0)
+        common_attn_metadata = self.shallow_copy_metadata(old_common_metadata)
+
+        if draft_step == 1:
+            if (
+                aclgraph_runtime_mode == CUDAGraphMode.FULL
+                and (pad_size := input_batch_size - batch_size) > 0
+            ):
+                common_attn_metadata.num_reqs = input_batch_size
+                common_attn_metadata.block_table_tensor = self._pad_tensor(
+                    common_attn_metadata.block_table_tensor, pad_size)
+                common_attn_metadata.seq_lens = self._pad_tensor(
+                    common_attn_metadata.seq_lens, pad_size)
+                common_attn_metadata.seq_lens_cpu = self._pad_tensor(
+                    common_attn_metadata.seq_lens_cpu, pad_size)
+                common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor(
+                    common_attn_metadata.num_computed_tokens_cpu, pad_size)
+                common_attn_metadata.query_start_loc = self.arange[
+                    :input_batch_size + 1]
+                common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
+                    self.token_arange_np[:input_batch_size + 1]).clone()
+            else:
+                common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
+                common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
+                    self.token_arange_np[:batch_size + 1]).clone()
+
+            common_attn_metadata.num_actual_tokens = batch_size
+            common_attn_metadata.max_query_len = 1
+            common_attn_metadata.decode_token_per_req = 1
+            common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
+            common_attn_metadata.graph_pad_size = -1
+            common_attn_metadata.num_input_tokens = input_batch_size
+
+        # The loop part
+        used_update_positions += 1
+
+        # NOTE(woosuk): We should handle the case where the draft model
+        # generates tokens beyond the max model length. Since it is complex
+        # to remove such requests from the batch, we keep them in the batch
+        # but adjust the position ids and slot mappings to avoid the
+        # out-of-range access during the model execution. The draft tokens
+        # generated with this adjustment should be ignored.
+        if self.uses_mrope:
+            exceeds_max_model_len = used_update_positions[
+                0] >= self.vllm_config.model_config.max_model_len
+            # Mask out the position ids that exceed the max model length.
+            # Otherwise, we may get out-of-range error in RoPE.
+            clamped_positions = torch.where(
+                exceeds_max_model_len.unsqueeze(0),
+                torch.zeros_like(used_update_positions), used_update_positions)
+        else:
+            exceeds_max_model_len = used_update_positions >= \
+                self.vllm_config.model_config.max_model_len
+            clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                            used_update_positions)
+
+        # For data integrity when async scheduling, we shouldn't use in place
+        # operations in case they are modified in next step's `prepare_input`
+        # of main model.
+        # Increment the sequence lengths.
+        common_attn_metadata.seq_lens[:batch_size] += 1
+        # For the requests that exceed the max model length, we set the
+        # sequence length to 1 to minimize their overheads in attention.
+        common_attn_metadata.seq_lens[:batch_size].masked_fill_(
+            exceeds_max_model_len, 1)
+
+        common_attn_metadata.seq_lens_cpu[:batch_size] = (
+            common_attn_metadata.seq_lens_cpu[:batch_size] + 1)
+        exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \
+            self.max_model_len
+        common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(
+            exceeds_mask, 1)
+        common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1
+        if self.uses_mrope:
+            common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0])
+        else:
+            common_attn_metadata.positions[:batch_size].copy_(clamped_positions)
+
+        if self.attn_metadata_builder is None:
+            attn_metadata_builder = self._get_attention_metadata_builder()
+        else:
+            attn_metadata_builder = self.attn_metadata_builder
+        block_size = attn_metadata_builder.kv_cache_spec.block_size
+
+        # Compute the slot mapping.
+        if self.uses_mrope:
+            block_numbers = clamped_positions[0] // block_size
+        else:
+            block_numbers = (clamped_positions // block_size)
+        block_ids = old_attn_metadata.block_tables.gather(
+            dim=1, index=block_numbers.view(-1, 1))
+        block_ids = block_ids.view(-1)
+        if self.uses_mrope:
+            slot_mapping = (block_ids * block_size +
+                            clamped_positions[0] % block_size)
+        else:
+            slot_mapping = (block_ids * block_size +
+                            clamped_positions % block_size)
+
+        # Mask out the slot mappings that exceed the max model length.
+        # Otherwise, the KV cache will be inadvertently updated with the
+        # padding tokens.
+        slot_mapping.masked_fill_(exceeds_max_model_len,
+                                  PADDING_SLOT_ID)
+        self.slot_mapping_group[draft_step][:slot_mapping.shape[0]].copy_(
+            slot_mapping.to(torch.int32))
+        self.slot_mapping_group[draft_step][slot_mapping.shape[0]:].fill_(
+            PADDING_SLOT_ID)
+        # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx]
+        common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][
+            :slot_mapping.shape[0]]
+
+        # Rebuild attention metadata
+        attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
+            common_attn_metadata=common_attn_metadata,
+            draft_index=draft_step,
+        )
+
+        return common_attn_metadata, attn_metadata
+
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
@@ -1011,7 +1136,7 @@ class EagleProposer(VllmEagleProposer):
        return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens

    # update full-graph params for one spec token
-    def _update_full_graph_params(self, forward_context, num_tokens):
+    def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None):
        if self.vllm_config.model_config.use_mla:
            if self.pcp_size * self.dcp_size > 1:
                update_mla_attn_dcp_pcp_params(self.update_stream,
@@ -1026,7 +1151,7 @@ class EagleProposer(VllmEagleProposer):
                                           num_tokens)
        else:
            update_attn_params(self.update_stream, forward_context,
-                               num_tokens, self.vllm_config)
+                               num_tokens, self.vllm_config, draft_attn_metadatas)

    # padding tensor into desired size
    def _pad_tensor(self, tensor, pad_size):