[Feature] Support to use fullgraph with eagle (#5118)
### What this PR does / why we need it?
We support to use full graph with eagle.
Change list:
1. Distinguish between processing graph_params and draft_graph_params in
attention_v1.
2. Adapt the full-graph mode in eagle_proposer, include:
1). If use full graph, make Fullgraph Wrapper when load model.
2). Build a new meatadata, set running mode in FULL and mark attention
update in dummy_run when in Fullgraph mode.
3). Fixed and fill any attn_metadata, such as
attn_metadata.slot_mapping.
4). Add a descriptor.
5). Set running mode and triggered update metadata.
3. Trans is_mtp_model to is_draft_model, and add the update of
workspace.
NOTE:
When set async_scheduling=True, the draft model will enforce execution
in eager mode.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from vllm.attention.layer import Attention
|
||||
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
|
||||
get_layers_from_vllm_config)
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
@@ -25,6 +26,8 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||
update_attn_params)
|
||||
from vllm_ascend.ops.rotary_embedding import update_cos_sin
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
|
||||
@@ -48,6 +51,8 @@ class EagleProposer(Proposer):
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.draft_model_config = self.speculative_config.draft_model_config
|
||||
self.method = self.speculative_config.method
|
||||
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
|
||||
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
|
||||
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
# We need to get the hidden size from the draft model config because
|
||||
@@ -56,9 +61,17 @@ class EagleProposer(Proposer):
|
||||
self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
|
||||
)
|
||||
|
||||
self.use_cuda_graph = (self.vllm_config.compilation_config.mode
|
||||
== CompilationMode.VLLM_COMPILE and
|
||||
not self.vllm_config.model_config.enforce_eager)
|
||||
# there is synchronization between mtp steps when enabling aclgraph,
|
||||
# disable aclgraph when use async scheduling to avoid the
|
||||
# synchronization overhead.
|
||||
# NOTE: we need to set aclgraph_runtime_mode to None in both dummy_run
|
||||
# and _propose.
|
||||
self.use_cuda_graph = (
|
||||
self.vllm_config.compilation_config.mode
|
||||
== CompilationMode.VLLM_COMPILE
|
||||
and not self.vllm_config.model_config.enforce_eager
|
||||
and not self.use_async_scheduling
|
||||
and not self.vllm_config.speculative_config.enforce_eager)
|
||||
|
||||
self.cudagraph_batch_sizes = list(
|
||||
sorted(
|
||||
@@ -74,8 +87,7 @@ class EagleProposer(Proposer):
|
||||
device=device,
|
||||
with_numpy=True,
|
||||
)
|
||||
self.decode_threshold = 1 + \
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
self.decode_threshold = 1 + self.num_speculative_tokens
|
||||
|
||||
# persistent buffers for cuda graph
|
||||
self.input_ids = torch.zeros(
|
||||
@@ -160,6 +172,19 @@ class EagleProposer(Proposer):
|
||||
else:
|
||||
self.model.lm_head = model.lm_head
|
||||
|
||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||
) and self.use_cuda_graph:
|
||||
self.update_stream = torch.npu.Stream()
|
||||
self.model = ACLGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
# get raw model out of the aclgraph wrapper.
|
||||
if isinstance(self.model, ACLGraphWrapper):
|
||||
return self.model.unwrap()
|
||||
return self.model
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(self,
|
||||
num_tokens: int,
|
||||
@@ -174,16 +199,73 @@ class EagleProposer(Proposer):
|
||||
# update global cos, sin
|
||||
update_cos_sin(self.positions[:num_tokens])
|
||||
|
||||
with set_ascend_forward_context(None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
self.model(
|
||||
input_ids=self.input_ids[:num_tokens],
|
||||
positions=self.positions[:num_tokens],
|
||||
hidden_states=self.hidden_states[:num_tokens],
|
||||
attn_metadata = None
|
||||
if not self.use_cuda_graph:
|
||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.FULL and len(
|
||||
self.runner.attn_groups) > 0:
|
||||
num_computed_tokens_cpu = (
|
||||
self.runner.input_batch.
|
||||
num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.runner.query_start_loc.gpu[:num_reqs + 1],
|
||||
query_start_loc_cpu=self.runner.query_start_loc.cpu[:num_reqs +
|
||||
1],
|
||||
seq_lens_cpu=self.runner.seq_lens.cpu,
|
||||
seq_lens=self.runner.seq_lens.gpu[:num_reqs],
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
num_input_tokens=num_tokens,
|
||||
max_query_len=self.num_speculative_tokens + 1,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor()[:num_reqs],
|
||||
slot_mapping=self.runner.input_batch.block_table[0].
|
||||
slot_mapping.gpu,
|
||||
positions=self.runner.positions.gpu,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
max_seq_len=0,
|
||||
)
|
||||
dummy_compute_logits(self.hidden_states)
|
||||
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata_eagle = builder.build_for_graph_capture(
|
||||
common_attn_metadata, AscendAttentionState.ChunkedPrefill)
|
||||
attn_metadata = {}
|
||||
for layer_name in [self.attn_layer_name]:
|
||||
attn_metadata[layer_name] = attn_metadata_eagle
|
||||
for i in range(self.num_speculative_tokens):
|
||||
if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_actual_tokens=0,
|
||||
in_profile_run=is_profile,
|
||||
batch_descriptor=batch_descriptor,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
is_draft_model=True):
|
||||
forward_context = get_forward_context()
|
||||
self.model(
|
||||
input_ids=self.input_ids[:num_tokens],
|
||||
positions=self.positions[:num_tokens],
|
||||
hidden_states=self.hidden_states[:num_tokens],
|
||||
)
|
||||
if (forward_context.cudagraph_runtime_mode
|
||||
== CUDAGraphMode.FULL
|
||||
and not forward_context.capturing):
|
||||
update_attn_params(
|
||||
self.update_stream,
|
||||
forward_context,
|
||||
num_tokens,
|
||||
self.vllm_config,
|
||||
)
|
||||
|
||||
def generate_token_ids(self,
|
||||
sampled_token_ids: torch.Tensor | list[list[int]],
|
||||
sampling_metadata: SamplingMetadata = None,
|
||||
@@ -343,7 +425,7 @@ class EagleProposer(Proposer):
|
||||
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||
|
||||
if self.name == SpecDcodeType.EAGLE3:
|
||||
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
||||
assert isinstance(self.get_model(), Eagle3LlamaForCausalLM)
|
||||
target_hidden_states = self.model.combine_hidden_states(
|
||||
target_hidden_states)
|
||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||
@@ -361,6 +443,14 @@ class EagleProposer(Proposer):
|
||||
else:
|
||||
num_input_tokens = num_tokens
|
||||
|
||||
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
|
||||
if self.use_cuda_graph:
|
||||
aclgraph_runtime_mode, batch_descriptor = \
|
||||
self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=True, has_lora=has_lora)
|
||||
else:
|
||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||
batch_descriptor = None
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.positions[:num_tokens] = target_positions
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
@@ -372,27 +462,40 @@ class EagleProposer(Proposer):
|
||||
# update global cos, sin
|
||||
update_cos_sin(self.positions[:num_input_tokens])
|
||||
|
||||
with set_ascend_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
with set_ascend_forward_context(
|
||||
{self.attn_layer_name: attn_metadata},
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_actual_tokens=num_tokens,
|
||||
batch_descriptor=batch_descriptor,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
is_draft_model=True):
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
)
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
# TODO: support mla in future.
|
||||
update_attn_params(
|
||||
self.update_stream,
|
||||
forward_context,
|
||||
num_input_tokens,
|
||||
self.vllm_config,
|
||||
)
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.vllm_config.speculative_config.num_speculative_tokens == 1:
|
||||
if self.num_speculative_tokens == 1:
|
||||
# [batch_size, 1]
|
||||
return draft_token_ids.view(-1, 1)
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_tensor = torch.zeros(
|
||||
(self.vllm_config.speculative_config.num_speculative_tokens,
|
||||
*draft_token_ids.shape),
|
||||
(self.num_speculative_tokens, *draft_token_ids.shape),
|
||||
dtype=draft_token_ids.dtype,
|
||||
device=self.device)
|
||||
draft_token_ids_tensor[0] = draft_token_ids
|
||||
@@ -417,9 +520,13 @@ class EagleProposer(Proposer):
|
||||
1:].tolist()
|
||||
attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
|
||||
attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
||||
for now_speculative in range(
|
||||
self.vllm_config.speculative_config.num_speculative_tokens -
|
||||
1):
|
||||
if self.use_cuda_graph:
|
||||
aclgraph_runtime_mode, batch_descriptor = \
|
||||
self.runner.cudagraph_dispatcher.dispatch(num_tokens=input_batch_size, uniform_decode=True, has_lora=has_lora)
|
||||
else:
|
||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||
batch_descriptor = None
|
||||
for now_speculative in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
@@ -467,6 +574,8 @@ class EagleProposer(Proposer):
|
||||
# NOTE: ASCEND slot_mapping must on cpu
|
||||
attn_metadata.slot_mapping[:slot_mapping_tmp.shape[0]].copy_(
|
||||
slot_mapping_tmp.to(torch.int32))
|
||||
attn_metadata.slot_mapping[slot_mapping_tmp.shape[0]:].fill_(
|
||||
PADDING_SLOT_ID)
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
@@ -474,20 +583,33 @@ class EagleProposer(Proposer):
|
||||
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
|
||||
attn_metadata.attn_mask = attn_mask
|
||||
# Run the model.
|
||||
|
||||
# update global cos, sin
|
||||
update_cos_sin(self.positions[:input_batch_size])
|
||||
|
||||
with set_ascend_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size):
|
||||
# Run the model.
|
||||
with set_ascend_forward_context(
|
||||
{self.attn_layer_name: attn_metadata},
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size,
|
||||
num_actual_tokens=batch_size,
|
||||
batch_descriptor=batch_descriptor,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
is_draft_model=True):
|
||||
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=self.input_ids[:input_batch_size],
|
||||
positions=self.positions[:input_batch_size],
|
||||
hidden_states=self.hidden_states[:input_batch_size],
|
||||
)
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
update_attn_params(
|
||||
self.update_stream,
|
||||
forward_context,
|
||||
input_batch_size,
|
||||
self.vllm_config,
|
||||
)
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size])
|
||||
|
||||
@@ -719,7 +841,7 @@ class EagleProposer(Proposer):
|
||||
common_attn_metadata.slot_mapping[token_indices])
|
||||
common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1)
|
||||
|
||||
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward
|
||||
# NOTE: Currently positions and seq_lens are not used in attn forward
|
||||
# so we do not need to fixed them. But if they are used in the future,
|
||||
# we should fixed them.
|
||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
@@ -779,7 +901,7 @@ class EagleProposer(Proposer):
|
||||
total_num_tokens = query_start_loc_cpu[-1].item()
|
||||
token_indices = self.arange[:total_num_tokens]
|
||||
|
||||
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward
|
||||
# NOTE: Currently positions and seq_lens are not used in attn forward
|
||||
# so we do not need to fixed them. But if they are used in the future,
|
||||
# we should fixed them.
|
||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
@@ -803,7 +925,8 @@ class EagleProposer(Proposer):
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
max_seq_len=0)
|
||||
|
||||
token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
|
||||
1 - num_rejected_tokens_gpu)
|
||||
query_start_loc = common_attn_metadata.query_start_loc[
|
||||
1:1 + num_rejected_tokens_gpu.shape[0]]
|
||||
token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu
|
||||
|
||||
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
||||
|
||||
Reference in New Issue
Block a user