[Feat][Bugfix][main] Adapted SP to eagle3 (#5562)
### What this PR does / why we need it?
Adapted sequence parallelism (SP) to EAGLE3.
There may still be some problems — e.g., accuracy issues in some scenarios,
and the `sp` + `dp` combination.
We will fix them in follow-up work.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
We tested it mainly with a new end-to-end (`e2e`) test:
```shell
pytest -s tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py::test_llama_qwen_eagle_acceptance
```
```text
.
=============================== warnings summary ===============================
<frozen importlib._bootstrap>:241
<frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute
<frozen importlib._bootstrap>:241
<frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============= 3 passed, 1 skipped, 2 warnings in 142.05s (0:02:22) =============
```
It passed.
- vLLM version: v0.13.0
- vLLM main:
7157596103
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -225,7 +225,6 @@ class EagleProposer(VllmEagleProposer):
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
max_seq_len=0,
|
||||
)
|
||||
dummy_compute_logits(self.hidden_states)
|
||||
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata_eagle = builder.build_for_graph_capture(
|
||||
@@ -233,6 +232,10 @@ class EagleProposer(VllmEagleProposer):
|
||||
attn_metadata = {}
|
||||
for layer_name in self.attn_layer_name:
|
||||
attn_metadata[layer_name] = attn_metadata_eagle
|
||||
|
||||
model_input_ids = self.input_ids[:num_tokens]
|
||||
model_positions = self.positions[:num_tokens]
|
||||
model_previous_hidden_states = self.hidden_states[:num_tokens]
|
||||
for i in range(self.num_speculative_tokens):
|
||||
if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||
@@ -245,12 +248,17 @@ class EagleProposer(VllmEagleProposer):
|
||||
batch_descriptor=batch_descriptor,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
is_draft_model=True):
|
||||
forward_context = get_forward_context()
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
model_previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
model_previous_hidden_states)
|
||||
|
||||
self.model(
|
||||
input_ids=self.input_ids[:num_tokens],
|
||||
positions=self.positions[:num_tokens],
|
||||
hidden_states=self.hidden_states[:num_tokens],
|
||||
input_ids=model_input_ids,
|
||||
positions=model_positions,
|
||||
hidden_states=model_previous_hidden_states,
|
||||
)
|
||||
forward_context = get_forward_context()
|
||||
if (forward_context.cudagraph_runtime_mode
|
||||
== CUDAGraphMode.FULL
|
||||
and not forward_context.capturing):
|
||||
@@ -261,6 +269,12 @@ class EagleProposer(VllmEagleProposer):
|
||||
self.vllm_config,
|
||||
)
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
model_previous_hidden_states, True)
|
||||
|
||||
dummy_compute_logits(self.hidden_states)
|
||||
|
||||
def _propose(
|
||||
self,
|
||||
# [num_tokens]
|
||||
@@ -338,10 +352,24 @@ class EagleProposer(VllmEagleProposer):
|
||||
batch_descriptor=batch_descriptor,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
is_draft_model=True):
|
||||
|
||||
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
|
||||
# `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
|
||||
model_input_ids = self.input_ids[:num_input_tokens]
|
||||
model_positions = self.positions[:num_input_tokens]
|
||||
model_hidden_states = self.hidden_states[:num_input_tokens]
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
# split hidden states along sequence dimension
|
||||
# positions should not be split?
|
||||
model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
model_hidden_states)
|
||||
# in acl-graph, `model_hidden_states` should be copy back to `self.hidden_states`?
|
||||
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
input_ids=model_input_ids,
|
||||
positions=model_positions,
|
||||
hidden_states=model_hidden_states,
|
||||
)
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
@@ -352,6 +380,14 @@ class EagleProposer(VllmEagleProposer):
|
||||
num_input_tokens,
|
||||
self.vllm_config,
|
||||
)
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
# merge hidden states along sequence dimension
|
||||
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
last_hidden_states.contiguous(), True)
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), True)
|
||||
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
@@ -470,10 +506,23 @@ class EagleProposer(VllmEagleProposer):
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
is_draft_model=True):
|
||||
|
||||
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
|
||||
# `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
|
||||
model_input_ids = self.input_ids[:input_batch_size]
|
||||
model_positions = self.positions[:input_batch_size]
|
||||
model_hidden_states = self.hidden_states[:input_batch_size]
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
# split hidden states along sequence dimension
|
||||
# positions should not be split?
|
||||
model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
model_hidden_states)
|
||||
# in acl-graph, `model_hidden_states` should be copy back to `self.hidden_states`?
|
||||
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=self.input_ids[:input_batch_size],
|
||||
positions=self.positions[:input_batch_size],
|
||||
hidden_states=self.hidden_states[:input_batch_size],
|
||||
input_ids=model_input_ids,
|
||||
positions=model_positions,
|
||||
hidden_states=model_hidden_states,
|
||||
)
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
@@ -483,6 +532,14 @@ class EagleProposer(VllmEagleProposer):
|
||||
input_batch_size,
|
||||
self.vllm_config,
|
||||
)
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
# merge hidden states along sequence dimension
|
||||
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
last_hidden_states.contiguous(), True)
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), True)
|
||||
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user