[Refactor][EAGLE] 5/N Update attn_metadata by common_attn_metadata (#5869)
### What this PR does / why we need it?
The EAGLE refactor plan is divided into several parts; this PR (5/N) changes
the attn_metadata update method by modifying common_attn_metadata and then
rebuilding the attention metadata from it.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ut
- vLLM version: v0.13.0
- vLLM main:
bde38c11df
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Signed-off-by: Zetong Li <slippersss@126.com>
Co-authored-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -471,22 +471,42 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
else:
|
else:
|
||||||
input_batch_size = batch_size
|
input_batch_size = batch_size
|
||||||
|
|
||||||
attn_metadata.num_actual_tokens = batch_size
|
|
||||||
attn_metadata.max_query_len = 1
|
|
||||||
attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1]
|
|
||||||
attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
|
|
||||||
attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
|
|
||||||
|
|
||||||
attn_metadata.actual_seq_lengths_q = attn_metadata.query_start_loc[
|
|
||||||
1:].tolist()
|
|
||||||
attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
|
|
||||||
attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
|
||||||
if self.use_cuda_graph:
|
if self.use_cuda_graph:
|
||||||
aclgraph_runtime_mode, batch_descriptor = \
|
aclgraph_runtime_mode, batch_descriptor = \
|
||||||
self.runner.cudagraph_dispatcher.dispatch(num_tokens=input_batch_size, uniform_decode=True, has_lora=has_lora)
|
self.runner.cudagraph_dispatcher.dispatch(num_tokens=input_batch_size, uniform_decode=True, has_lora=has_lora)
|
||||||
else:
|
else:
|
||||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||||
batch_descriptor = None
|
batch_descriptor = None
|
||||||
|
|
||||||
|
if (
|
||||||
|
aclgraph_runtime_mode == CUDAGraphMode.FULL
|
||||||
|
and (pad_size := input_batch_size - batch_size) > 0
|
||||||
|
):
|
||||||
|
common_attn_metadata.num_reqs = input_batch_size
|
||||||
|
common_attn_metadata.block_table_tensor = self._pad_tensor(
|
||||||
|
common_attn_metadata.block_table_tensor, pad_size)
|
||||||
|
common_attn_metadata.seq_lens = self._pad_tensor(
|
||||||
|
common_attn_metadata.seq_lens, pad_size)
|
||||||
|
common_attn_metadata.seq_lens_cpu = self._pad_tensor(
|
||||||
|
common_attn_metadata.seq_lens_cpu, pad_size)
|
||||||
|
common_attn_metadata.num_computed_tokens_cpu = self._pad_tensor(
|
||||||
|
common_attn_metadata.num_computed_tokens_cpu, pad_size)
|
||||||
|
common_attn_metadata.query_start_loc = self.arange[
|
||||||
|
:input_batch_size + 1]
|
||||||
|
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
|
||||||
|
self.token_arange_np[:input_batch_size + 1]).clone()
|
||||||
|
else:
|
||||||
|
common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||||
|
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
|
||||||
|
self.token_arange_np[:batch_size + 1]).clone()
|
||||||
|
|
||||||
|
common_attn_metadata.num_actual_tokens = batch_size
|
||||||
|
common_attn_metadata.max_query_len = 1
|
||||||
|
common_attn_metadata.decode_token_per_req = 1
|
||||||
|
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
|
common_attn_metadata.graph_pad_size = -1
|
||||||
|
common_attn_metadata.num_input_tokens = input_batch_size
|
||||||
|
|
||||||
for now_speculative in range(self.num_speculative_tokens - 1):
|
for now_speculative in range(self.num_speculative_tokens - 1):
|
||||||
# Update the inputs.
|
# Update the inputs.
|
||||||
# cast to int32 is crucial when eagle model is compiled.
|
# cast to int32 is crucial when eagle model is compiled.
|
||||||
@@ -513,18 +533,27 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||||
positions)
|
positions)
|
||||||
|
|
||||||
# TODO: Increment the sequence lengths.
|
# For data integrity when async scheduling, we shouldn't use in place
|
||||||
|
# operations in case they are modified in next step's `prepare_input`
|
||||||
attn_metadata.seq_lens = attn_metadata.seq_lens + 1
|
# of main model.
|
||||||
attn_metadata.seq_lens_list = [
|
# Increment the sequence lengths.
|
||||||
_ + 1 for _ in attn_metadata.seq_lens_list
|
common_attn_metadata.seq_lens[:batch_size] += 1
|
||||||
]
|
|
||||||
# TODO: Consider max model length.
|
|
||||||
# attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
|
|
||||||
# self.max_model_len)
|
|
||||||
# For the requests that exceed the max model length, we set the
|
# For the requests that exceed the max model length, we set the
|
||||||
# TODO: sequence length to 1 to minimize their overheads in attention.
|
# sequence length to 1 to minimize their overheads in attention.
|
||||||
|
common_attn_metadata.seq_lens[:batch_size].masked_fill_(
|
||||||
|
exceeds_max_model_len, 1)
|
||||||
|
|
||||||
|
common_attn_metadata.seq_lens_cpu[:batch_size] = (
|
||||||
|
common_attn_metadata.seq_lens_cpu[:batch_size] + 1)
|
||||||
|
exceeds_mask = common_attn_metadata.seq_lens_cpu[:batch_size] >= \
|
||||||
|
self.max_model_len
|
||||||
|
common_attn_metadata.seq_lens_cpu[:batch_size].masked_fill_(
|
||||||
|
exceeds_mask, 1)
|
||||||
|
common_attn_metadata.num_computed_tokens_cpu[:batch_size] += 1
|
||||||
|
if self.uses_mrope:
|
||||||
|
common_attn_metadata.positions[:batch_size].copy_(clamped_positions[0])
|
||||||
|
else:
|
||||||
|
common_attn_metadata.positions[:batch_size].copy_(clamped_positions)
|
||||||
if self.attn_metadata_builder is None:
|
if self.attn_metadata_builder is None:
|
||||||
attn_metadata_builder = self._get_attention_metadata_builder()
|
attn_metadata_builder = self._get_attention_metadata_builder()
|
||||||
else:
|
else:
|
||||||
@@ -540,22 +569,31 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
dim=1, index=block_numbers.view(-1, 1))
|
dim=1, index=block_numbers.view(-1, 1))
|
||||||
block_ids = block_ids.view(-1)
|
block_ids = block_ids.view(-1)
|
||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
slot_mapping_tmp = (block_ids * block_size +
|
slot_mapping = (block_ids * block_size +
|
||||||
clamped_positions[0] % block_size)
|
clamped_positions[0] % block_size)
|
||||||
else:
|
else:
|
||||||
slot_mapping_tmp = (block_ids * block_size +
|
slot_mapping = (block_ids * block_size +
|
||||||
clamped_positions % block_size)
|
clamped_positions % block_size)
|
||||||
|
|
||||||
# Mask out the slot mappings that exceed the max model length.
|
# Mask out the slot mappings that exceed the max model length.
|
||||||
# Otherwise, the KV cache will be inadvertently updated with the
|
# Otherwise, the KV cache will be inadvertently updated with the
|
||||||
# padding tokens.
|
# padding tokens.
|
||||||
slot_mapping_tmp.masked_fill_(exceeds_max_model_len,
|
slot_mapping.masked_fill_(exceeds_max_model_len,
|
||||||
PADDING_SLOT_ID)
|
PADDING_SLOT_ID)
|
||||||
# NOTE: ASCEND slot_mapping must on cpu
|
|
||||||
attn_metadata.slot_mapping[:slot_mapping_tmp.shape[0]].copy_(
|
common_attn_metadata.slot_mapping[:slot_mapping.shape[0]].copy_(
|
||||||
slot_mapping_tmp.to(torch.int32))
|
slot_mapping.to(torch.int32))
|
||||||
attn_metadata.slot_mapping[slot_mapping_tmp.shape[0]:].fill_(
|
common_attn_metadata.slot_mapping[slot_mapping.shape[0]:].fill_(
|
||||||
PADDING_SLOT_ID)
|
PADDING_SLOT_ID)
|
||||||
|
|
||||||
|
# Rebuild attention metadata
|
||||||
|
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
draft_index=now_speculative + 1,
|
||||||
|
)
|
||||||
|
for layer_name in self.attn_layer_names:
|
||||||
|
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||||
|
|
||||||
# copy inputs to buffer for cudagraph
|
# copy inputs to buffer for cudagraph
|
||||||
self.input_ids[:batch_size] = input_ids
|
self.input_ids[:batch_size] = input_ids
|
||||||
self._set_positions(batch_size, clamped_positions)
|
self._set_positions(batch_size, clamped_positions)
|
||||||
@@ -569,9 +607,6 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
else:
|
else:
|
||||||
input_ids = self.input_ids[:input_batch_size]
|
input_ids = self.input_ids[:input_batch_size]
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
|
|
||||||
|
|
||||||
attn_metadata.attn_mask = attn_mask
|
|
||||||
|
|
||||||
# update global cos, sin
|
# update global cos, sin
|
||||||
update_cos_sin(self._get_positions(input_batch_size))
|
update_cos_sin(self._get_positions(input_batch_size))
|
||||||
@@ -981,6 +1016,12 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
update_attn_params(self.update_stream, forward_context,
|
update_attn_params(self.update_stream, forward_context,
|
||||||
num_tokens, self.vllm_config)
|
num_tokens, self.vllm_config)
|
||||||
|
|
||||||
|
# padding tensor into desired size
|
||||||
|
def _pad_tensor(self, tensor, pad_size):
|
||||||
|
pad = [0] * (2 * tensor.dim() - 1) + [pad_size]
|
||||||
|
padded_tensor = F.pad(tensor, pad, mode="constant", value=0)
|
||||||
|
return padded_tensor
|
||||||
|
|
||||||
def maybe_pad_and_reduce(
|
def maybe_pad_and_reduce(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user