From a8576ec610edf2eca9f98272d42187345277f737 Mon Sep 17 00:00:00 2001 From: lilinsiman Date: Tue, 20 Jan 2026 10:06:00 +0800 Subject: [PATCH] [Refactor][EAGLE] 5/N Update attn_metadata by common_attn_metadata (#5869) ### What this PR does / why we need it? 4/N EAGLE refactor plan divided into many parts, this PR is the first change, which modifies the attn_metadata update method by modifying common_metadata and then rebuilding the code. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ut - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/bde38c11df0ea066a740efe9b77fff5418be45df --------- Signed-off-by: lilinsiman Signed-off-by: Zetong Li Co-authored-by: Zetong Li --- vllm_ascend/spec_decode/eagle_proposer.py | 101 +++++++++++++++------- 1 file changed, 71 insertions(+), 30 deletions(-) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 5322c5cd..b8b9094b 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -471,22 +471,42 @@ class EagleProposer(VllmEagleProposer): else: input_batch_size = batch_size - attn_metadata.num_actual_tokens = batch_size - attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1] - attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size - attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens - - attn_metadata.actual_seq_lengths_q = attn_metadata.query_start_loc[ 1:].tolist() - attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist() - attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill if self.use_cuda_graph: aclgraph_runtime_mode, batch_descriptor = \ self.runner.cudagraph_dispatcher.dispatch(num_tokens=input_batch_size, uniform_decode=True, 
has_lora=has_lora) else: aclgraph_runtime_mode = CUDAGraphMode.NONE batch_descriptor = None + + if ( + aclgraph_runtime_mode == CUDAGraphMode.FULL + and (pad_size := input_batch_size - batch_size) > 0 + ): + common_attn_metadata.num_reqs = input_batch_size + common_attn_metadata.block_table_tensor = self._pad_tensor( + common_attn_metadata.block_table_tensor, pad_size) + common_attn_metadata.seq_lens = self._pad_tensor( + common_attn_metadata.seq_lens, pad_size) + common_attn_metadata.seq_lens_cpu = self._pad_tensor( + common_attn_metadata.seq_lens_cpu, pad_size) + common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor( + common_attn_metadata.num_computed_tokens_cpu, pad_size) + common_attn_metadata.query_start_loc = self.arange[ + :input_batch_size + 1] + common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[:input_batch_size + 1]).clone() + else: + common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] + common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[:batch_size + 1]).clone() + + common_attn_metadata.num_actual_tokens = batch_size + common_attn_metadata.max_query_len = 1 + common_attn_metadata.decode_token_per_req = 1 + common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill + common_attn_metadata.graph_pad_size = -1 + common_attn_metadata.num_input_tokens = input_batch_size + for now_speculative in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. @@ -513,18 +533,27 @@ class EagleProposer(VllmEagleProposer): clamped_positions = torch.where(exceeds_max_model_len, 0, positions) - # TODO: Increment the sequence lengths. - - attn_metadata.seq_lens = attn_metadata.seq_lens + 1 - attn_metadata.seq_lens_list = [ - _ + 1 for _ in attn_metadata.seq_lens_list - ] - # TODO: Consider max model length. 
- # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - # self.max_model_len) + # For data integrity when async scheduling, we shouldn't use in place + # operations in case they are modified in next step's `prepare_input` + # of main model. + # Increment the sequence lengths. + common_attn_metadata.seq_lens[:batch_size] += 1 # For the requests that exceed the max model length, we set the - # TODO: sequence length to 1 to minimize their overheads in attention. + # sequence length to 1 to minimize their overheads in attention. + common_attn_metadata.seq_lens[:batch_size].masked_fill_( + exceeds_max_model_len, 1) + common_attn_metadata.seq_lens_cpu[:batch_size] = ( + common_attn_metadata.seq_lens_cpu[:batch_size] + 1) + exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \ + self.max_model_len + common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_( + exceeds_mask, 1) + common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1 + if self.uses_mrope: + common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0]) + else: + common_attn_metadata.positions[:batch_size].copy_(clamped_positions) if self.attn_metadata_builder is None: attn_metadata_builder = self._get_attention_metadata_builder() else: @@ -540,22 +569,31 @@ class EagleProposer(VllmEagleProposer): dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) if self.uses_mrope: - slot_mapping_tmp = (block_ids * block_size + + slot_mapping = (block_ids * block_size + clamped_positions[0] % block_size) else: - slot_mapping_tmp = (block_ids * block_size + + slot_mapping = (block_ids * block_size + clamped_positions % block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
- slot_mapping_tmp.masked_fill_(exceeds_max_model_len, + slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) - # NOTE: ASCEND slot_mapping must on cpu - attn_metadata.slot_mapping[:slot_mapping_tmp.shape[0]].copy_( - slot_mapping_tmp.to(torch.int32)) - attn_metadata.slot_mapping[slot_mapping_tmp.shape[0]:].fill_( + + common_attn_metadata.slot_mapping[:slot_mapping.shape[0]].copy_( + slot_mapping.to(torch.int32)) + common_attn_metadata.slot_mapping[slot_mapping.shape[0]:].fill_( PADDING_SLOT_ID) + + # Rebuild attention metadata + attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore + common_attn_metadata=common_attn_metadata, + draft_index=now_speculative + 1, + ) + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids self._set_positions(batch_size, clamped_positions) @@ -569,9 +607,6 @@ class EagleProposer(VllmEagleProposer): else: input_ids = self.input_ids[:input_batch_size] inputs_embeds = None - attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask() - - attn_metadata.attn_mask = attn_mask # update global cos, sin update_cos_sin(self._get_positions(input_batch_size)) @@ -981,6 +1016,12 @@ class EagleProposer(VllmEagleProposer): update_attn_params(self.update_stream, forward_context, num_tokens, self.vllm_config) + # padding tensor into desired size + def _pad_tensor(self, tensor, pad_size): + pad = [0] * (2 * tensor.dim() - 1) + [pad_size] + padded_tensor = F.pad(tensor, pad, mode="constant", value=0) + return padded_tensor + def maybe_pad_and_reduce( self, hidden_states: torch.Tensor,