[Refactor][EAGLE] 4/N extract common methods from eagle and mtp (#5870)
### What this PR does / why we need it?
This PR aims to extract common methods from eagle_proposer and
mtp_proposer. This is a small step towards merging eagle and mtp.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
by ci
- vLLM version: v0.13.0
- vLLM main:
bde38c11df
---------
Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -89,7 +89,7 @@ class MtpProposer(EagleProposer):
|
||||
attn_metadata_mtp = builder.build_for_graph_capture(
|
||||
common_attn_metadata, attn_state)
|
||||
attn_metadata = {}
|
||||
for layer_name in self.attn_layer_name:
|
||||
for layer_name in self.attn_layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_mtp
|
||||
else:
|
||||
attn_metadata = None
|
||||
@@ -112,12 +112,8 @@ class MtpProposer(EagleProposer):
|
||||
batch_descriptor=batch_descriptor,
|
||||
is_draft_model=True,
|
||||
in_profile_run=is_profile):
|
||||
if self.enable_shared_expert_dp:
|
||||
positions = positions.unsqueeze(-1)
|
||||
positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
|
||||
positions = positions.squeeze(-1)
|
||||
previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
previous_hidden_states)
|
||||
previous_hidden_states, positions = self.maybe_pad_and_reduce(
|
||||
previous_hidden_states, positions)
|
||||
self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
hidden_states=previous_hidden_states)
|
||||
@@ -126,11 +122,8 @@ class MtpProposer(EagleProposer):
|
||||
not forward_context.capturing and not self.use_sparse:
|
||||
self._update_full_graph_params(forward_context, num_tokens)
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
positions, True)
|
||||
previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
previous_hidden_states, True)
|
||||
previous_hidden_states, positions, _ = self.maybe_all_gather_and_unpad(
|
||||
previous_hidden_states, positions)
|
||||
dummy_compute_logits(previous_hidden_states)
|
||||
if with_prefill:
|
||||
break
|
||||
@@ -249,11 +242,11 @@ class MtpProposer(EagleProposer):
|
||||
assert self.runner is not None
|
||||
|
||||
# Note(qcs): We may need to refactor these check logics.
|
||||
if self.runner.use_aclgraph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[
|
||||
if self.use_cuda_graph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[
|
||||
-1]:
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
num_scheduled_tokens)
|
||||
elif self.use_aclgraph and num_tokens <= self.runner.cudagraph_batch_sizes[
|
||||
elif self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[
|
||||
-1]:
|
||||
# Acl graph mode, add padding to the batch size
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
@@ -304,7 +297,7 @@ class MtpProposer(EagleProposer):
|
||||
attn_metadata_mtp = builder.build(0, common_attn_metadata,
|
||||
self.runner.get_model())
|
||||
attn_metadata = {}
|
||||
for layer_name in self.attn_layer_name:
|
||||
for layer_name in self.attn_layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_mtp
|
||||
|
||||
for step in range(self.num_speculative_tokens):
|
||||
@@ -324,26 +317,8 @@ class MtpProposer(EagleProposer):
|
||||
positions = self.positions[:num_input_tokens]
|
||||
hidden_states = self.hidden_states[:num_input_tokens]
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
# positions [N] -> [N, 1] for padding
|
||||
positions = positions.unsqueeze(-1)
|
||||
positions = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
positions)
|
||||
positions = positions.squeeze(-1)
|
||||
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
hidden_states)
|
||||
|
||||
for layer_name in self.attn_layer_name:
|
||||
decode_metadata = getattr(attn_metadata[layer_name],
|
||||
"decode", None)
|
||||
if self.use_async_scheduling and decode_metadata is not None:
|
||||
actual_size = len(
|
||||
decode_metadata.actual_seq_lengths_q)
|
||||
|
||||
decode_metadata.seq_lens_list = \
|
||||
decode_metadata.seq_lens_list[:actual_size]
|
||||
decode_metadata.block_table = \
|
||||
decode_metadata.block_table[:actual_size]
|
||||
hidden_states, positions = self.maybe_pad_and_reduce(
|
||||
hidden_states, positions)
|
||||
|
||||
hidden_states = self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
@@ -353,11 +328,8 @@ class MtpProposer(EagleProposer):
|
||||
self._update_full_graph_params(forward_context,
|
||||
num_input_tokens)
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), True)
|
||||
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
positions.contiguous(), True)
|
||||
hidden_states, positions, _ = self.maybe_all_gather_and_unpad(
|
||||
hidden_states, positions)
|
||||
|
||||
num_indices = last_token_indices.shape[0]
|
||||
if lmhead_tp_enable():
|
||||
@@ -398,7 +370,7 @@ class MtpProposer(EagleProposer):
|
||||
if step == self.num_speculative_tokens - 1 or with_prefill:
|
||||
break
|
||||
|
||||
attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
|
||||
attn_metadata_i = attn_metadata[self.attn_layer_names[0]]
|
||||
|
||||
if step == 0:
|
||||
positions = target_positions[last_token_indices]
|
||||
|
||||
Reference in New Issue
Block a user