[Refactor][EAGLE] 5/N Update attn_metadata by common_attn_metadata (#5869)

### What this PR does / why we need it?
This PR is part 5/N of the EAGLE refactor plan, which is divided into many
parts. It changes how attn_metadata is updated during drafting: instead of
patching attn_metadata fields in place, we modify common_attn_metadata and
then rebuild attn_metadata from it.
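
A minimal sketch of the new flow, with stand-in types (the real
`build_for_drafting` derives many more fields than shown here):

```python
from dataclasses import dataclass

@dataclass
class CommonMeta:      # stand-in for vLLM's CommonAttentionMetadata
    num_actual_tokens: int
    max_query_len: int

@dataclass
class AttnMeta:        # stand-in for the per-layer attention metadata
    num_actual_tokens: int
    max_query_len: int

def build_for_drafting(common: CommonMeta, draft_index: int) -> AttnMeta:
    # the real builder also derives slot mappings, seq lens, etc.
    return AttnMeta(common.num_actual_tokens, common.max_query_len)

common = CommonMeta(num_actual_tokens=8, max_query_len=1)
common.num_actual_tokens = 16                      # mutate the shared struct once
attn = build_for_drafting(common, draft_index=1)   # rebuild, don't patch
```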

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
UT (unit tests).
- vLLM version: v0.13.0
- vLLM main: bde38c11df

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Signed-off-by: Zetong Li <slippersss@126.com>
Co-authored-by: Zetong Li <slippersss@126.com>
lilinsiman
2026-01-20 10:06:00 +08:00
committed by GitHub
parent f58e110afe
commit a8576ec610


@@ -471,22 +471,42 @@ class EagleProposer(VllmEagleProposer):
else:
input_batch_size = batch_size
- attn_metadata.num_actual_tokens = batch_size
- attn_metadata.max_query_len = 1
- attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1]
- attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
- attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
- attn_metadata.actual_seq_lengths_q = attn_metadata.query_start_loc[
-     1:].tolist()
- attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
- attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
if self.use_cuda_graph:
aclgraph_runtime_mode, batch_descriptor = \
self.runner.cudagraph_dispatcher.dispatch(num_tokens=input_batch_size, uniform_decode=True, has_lora=has_lora)
else:
aclgraph_runtime_mode = CUDAGraphMode.NONE
batch_descriptor = None
+ if (
+     aclgraph_runtime_mode == CUDAGraphMode.FULL
+     and (pad_size := input_batch_size - batch_size) > 0
+ ):
+     common_attn_metadata.num_reqs = input_batch_size
+     common_attn_metadata.block_table_tensor = self._pad_tensor(
+         common_attn_metadata.block_table_tensor, pad_size)
+     common_attn_metadata.seq_lens = self._pad_tensor(
+         common_attn_metadata.seq_lens, pad_size)
+     common_attn_metadata.seq_lens_cpu = self._pad_tensor(
+         common_attn_metadata.seq_lens_cpu, pad_size)
+     common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor(
+         common_attn_metadata.num_computed_tokens_cpu, pad_size)
+     common_attn_metadata.query_start_loc = self.arange[
+         :input_batch_size + 1]
+     common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
+         self.token_arange_np[:input_batch_size + 1]).clone()
+ else:
+     common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
+     common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
+         self.token_arange_np[:batch_size + 1]).clone()
+ common_attn_metadata.num_actual_tokens = batch_size
+ common_attn_metadata.max_query_len = 1
+ common_attn_metadata.decode_token_per_req = 1
+ common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
+ common_attn_metadata.graph_pad_size = -1
+ common_attn_metadata.num_input_tokens = input_batch_size
for now_speculative in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
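
The padding branch added above keeps tensor shapes fixed when a captured full
graph is replayed at an `input_batch_size` larger than the live `batch_size`.
A self-contained sketch of that padding, with illustrative tensors (pure
PyTorch, no vLLM types):

```python
import torch
import torch.nn.functional as F

batch_size, input_batch_size = 6, 8       # graph was captured for 8 requests
pad_size = input_batch_size - batch_size

seq_lens = torch.full((batch_size,), 4, dtype=torch.int32)
block_table = torch.arange(batch_size * 2).view(batch_size, 2)

if pad_size > 0:
    # append zeros on the trailing side of dim 0, as _pad_tensor does
    seq_lens = F.pad(seq_lens, [0, pad_size])
    block_table = F.pad(block_table, [0, 0, 0, pad_size])

assert seq_lens.shape[0] == block_table.shape[0] == input_batch_size
```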
@@ -513,18 +533,27 @@ class EagleProposer(VllmEagleProposer):
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
- # TODO: Increment the sequence lengths.
- attn_metadata.seq_lens = attn_metadata.seq_lens + 1
- attn_metadata.seq_lens_list = [
-     _ + 1 for _ in attn_metadata.seq_lens_list
- ]
- # TODO: Consider max model length.
- # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
- #                                 self.max_model_len)
+ # For data integrity with async scheduling, we shouldn't use in-place
+ # operations, in case they are modified in the next step's
+ # `prepare_input` of the main model.
+ # Increment the sequence lengths.
+ common_attn_metadata.seq_lens[:batch_size] += 1
# For the requests that exceed the max model length, we set the
- # TODO: sequence length to 1 to minimize their overheads in attention.
+ # sequence length to 1 to minimize their overheads in attention.
+ common_attn_metadata.seq_lens[:batch_size].masked_fill_(
+     exceeds_max_model_len, 1)
+ common_attn_metadata.seq_lens_cpu[:batch_size] = (
+     common_attn_metadata.seq_lens_cpu[:batch_size] + 1)
+ exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \
+     self.max_model_len
+ common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(
+     exceeds_mask, 1)
+ common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1
+ if self.uses_mrope:
+     common_attn_metadata.positions[:batch_size].copy_(
+         clamped_positions[0])
+ else:
+     common_attn_metadata.positions[:batch_size].copy_(
+         clamped_positions)
if self.attn_metadata_builder is None:
attn_metadata_builder = self._get_attention_metadata_builder()
else:
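
The increment-and-clamp added above can be checked in isolation. A minimal
sketch, assuming the same `>= max_model_len` condition as the CPU path:

```python
import torch

max_model_len = 10
seq_lens = torch.tensor([9, 10, 4])

seq_lens += 1                          # one more token drafted per request
exceeds = seq_lens >= max_model_len    # requests that ran past the limit
seq_lens.masked_fill_(exceeds, 1)      # length 1 => near-zero attention cost
print(seq_lens)                        # tensor([1, 1, 5])
```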
@@ -540,22 +569,31 @@ class EagleProposer(VllmEagleProposer):
dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
if self.uses_mrope:
- slot_mapping_tmp = (block_ids * block_size +
+ slot_mapping = (block_ids * block_size +
clamped_positions[0] % block_size)
else:
- slot_mapping_tmp = (block_ids * block_size +
+ slot_mapping = (block_ids * block_size +
clamped_positions % block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
- slot_mapping_tmp.masked_fill_(exceeds_max_model_len,
+ slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)
# NOTE: ASCEND slot_mapping must be on CPU
- attn_metadata.slot_mapping[:slot_mapping_tmp.shape[0]].copy_(
-     slot_mapping_tmp.to(torch.int32))
- attn_metadata.slot_mapping[slot_mapping_tmp.shape[0]:].fill_(
+ common_attn_metadata.slot_mapping[:slot_mapping.shape[0]].copy_(
+     slot_mapping.to(torch.int32))
+ common_attn_metadata.slot_mapping[slot_mapping.shape[0]:].fill_(
PADDING_SLOT_ID)
+ # Rebuild attention metadata
+ attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
+     common_attn_metadata=common_attn_metadata,
+     draft_index=now_speculative + 1,
+ )
+ for layer_name in self.attn_layer_names:
+     per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self._set_positions(batch_size, clamped_positions)
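
For context on the slot-mapping arithmetic above, a standalone sketch;
`block_size` and the `PADDING_SLOT_ID` value here are illustrative
assumptions:

```python
import torch

PADDING_SLOT_ID = -1        # sentinel; the actual value comes from vLLM
block_size = 4

block_ids = torch.tensor([2, 2, 5])
positions = torch.tensor([9, 11, 3])
exceeds_max_model_len = torch.tensor([False, True, False])

# a token's KV slot = base offset of its block + offset within the block
slot_mapping = block_ids * block_size + positions % block_size
# tokens past max_model_len must not overwrite real KV-cache entries
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
print(slot_mapping)         # tensor([ 9, -1, 23])
```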
@@ -569,9 +607,6 @@ class EagleProposer(VllmEagleProposer):
else:
input_ids = self.input_ids[:input_batch_size]
inputs_embeds = None
- attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
- attn_metadata.attn_mask = attn_mask
# update global cos, sin
update_cos_sin(self._get_positions(input_batch_size))
@@ -981,6 +1016,12 @@ class EagleProposer(VllmEagleProposer):
update_attn_params(self.update_stream, forward_context,
num_tokens, self.vllm_config)
+ # Pad a tensor to the desired size along its first dimension.
+ def _pad_tensor(self, tensor, pad_size):
+     pad = [0] * (2 * tensor.dim() - 1) + [pad_size]
+     padded_tensor = F.pad(tensor, pad, mode="constant", value=0)
+     return padded_tensor
def maybe_pad_and_reduce(
self,
hidden_states: torch.Tensor,
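
For reference, the pad spec built in `_pad_tensor` above pads only the
trailing side of the first dimension. A standalone check (pure PyTorch):

```python
import torch
import torch.nn.functional as F

def pad_tensor(tensor: torch.Tensor, pad_size: int) -> torch.Tensor:
    # F.pad reads (left, right) pairs starting from the LAST dimension,
    # so the single nonzero entry at the end pads the trailing side of
    # dim 0: extra rows for a block table, extra entries for seq_lens.
    pad = [0] * (2 * tensor.dim() - 1) + [pad_size]
    return F.pad(tensor, pad, mode="constant", value=0)

print(pad_tensor(torch.tensor([5, 7, 3]), 2))   # tensor([5, 7, 3, 0, 0])
print(pad_tensor(torch.ones(2, 3), 1).shape)    # torch.Size([3, 3])
```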