[eagle3][pcp] fix acceptance rate for eagle3 and pcp enabled (#7549)
### What this PR does / why we need it?
fix the position 3 acceptance rate for eagle3 and pcp enabled
detail:
In the merged graph of eagle_proposer, the code logic was changed from
updating the code once before the forward pass of the draft model to
updating all three positions of common_attn_metadata in the merged graph
before performing the forward pass of the model. As a result, the update
of position 2 and position 3 affected the update of position 1.
For example, in the following field:
common_attn_metadata.block_table_tensor[:batch_size] =
common_attn_metadata.block_table_tensor[block_indices]
When updating the block_table_tensor at position 2, the modification of
this field occurred at the original address of common_attn_metadata. As
a result, the parameter at position 1 was also modified, but the forward
pass at position 1 had not been performed. Therefore, a copy of the
address of block_table_tensor needs to be made, and the modification
needs to be performed on the new address to ensure complete isolation
between positions.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
tests and ut
- vLLM version: v0.18.0
- vLLM main:
8b6325758c
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
@@ -570,6 +570,12 @@ class SpecDecodeBaseProposer(EagleProposer):
|
||||
|
||||
# Copy the old attn_metadata and update
|
||||
attn_metadata_i = per_layer_attn_metadata[self.attn_layer_names[0]]
|
||||
|
||||
# Clone the data so that when calculating the data at position 2 and position 3
|
||||
# in the merged graph, it does not affect position 1
|
||||
# FIXME(lilinsiman)
|
||||
common_attn_metadata.block_table_tensor = common_attn_metadata.block_table_tensor.clone()
|
||||
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
if self.num_speculative_tokens > 1 and not attn_metadata_i.num_prefills:
|
||||
# For pcp/dcp, tokens are split across different cp ranks,
|
||||
@@ -1136,9 +1142,16 @@ class SpecDecodeBaseProposer(EagleProposer):
|
||||
common_attn_metadata.num_input_tokens = input_batch_size
|
||||
|
||||
# The loop part
|
||||
|
||||
used_update_positions += 1
|
||||
|
||||
# Clone the data so that when calculating the data at position 2 and position 3
|
||||
# in the merged graph, it does not affect position 1
|
||||
# FIXME(lilinsiman)
|
||||
common_attn_metadata.seq_lens = common_attn_metadata.seq_lens.clone()
|
||||
common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu.clone()
|
||||
common_attn_metadata.num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu.clone()
|
||||
common_attn_metadata.positions = common_attn_metadata.positions.clone()
|
||||
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length. Since it is complex
|
||||
# to remove such requests from the batch, we keep them in the batch
|
||||
|
||||
Reference in New Issue
Block a user