[feat][spec decode]Unified draft parallel (#6766)
### What this PR does / why we need it?
Implement a unified parallelized speculative decoding in VLLM
Ascend,which can simultaneously support parallel speculative inference
schemes such as Pard, P-Eagle, etc. refer to
https://github.com/vllm-project/vllm-ascend/pull/6565 and
https://github.com/vllm-project/vllm-ascend/pull/4078
### How was this patch tested?
run with parallel drafting script:
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
--tensor-parallel-size 1 \
--max-model-len 4096 \
--no-enable-prefix-caching \
--port 8811 \
--speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method":
"draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
base script:
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
--tensor-parallel-size 1 \
--max-model-len 4096 \
--no-enable-prefix-caching \
--port 8811
benchmark script:
MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
--temperature 0 \
--model /model/Llama-3.1-8B-Instruct \
--backend openai-chat \
--endpoint /v1/chat/completions \
--dataset-name hf \
--dataset-path philschmid/mt-bench \
--num-prompts ${NUM_PROMPTS} \
--max-concurrency ${MAX_CONCURRENCY} \
--seed 1234
test results :
base(without spec decode): TTFT 79.46ms TPOT 26.99ms
output_tokens_throughput 36.75 tok/s
this pr(with parallel drafting): TTFT 72.24ms TPOT 13.45ms
output_tokens_throughput 72.98 tok/s
per-position acceptance(from position 0 to 7):
79.48%、56.93%、40%、27.90%、19.79%、14.25%、10.57%、7.61%.
----------------------------------------------------------------------
run on qwen3 model script :
export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1
vllm serve $target \
--tensor-parallel-size 1 \
--max-model-len 4096 \
--no-enable-prefix-caching \
--port 8811 \
--speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method":
"draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
cc @NickJudyHvv
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>
This commit is contained in:
@@ -284,6 +284,9 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
if isinstance(self.kv_cache_spec, CrossAttentionSpec):
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
slot_mapping = common_attn_metadata.slot_mapping.to(torch.int32)
|
||||
elif self.speculative_config and self.speculative_config.parallel_drafting:
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
|
||||
# Get attn_mask and swa_mask from singleton AttentionMaskBuilder
|
||||
|
||||
@@ -24,6 +24,7 @@ def prepare_inputs_padded_kernel(
|
||||
valid_sampled_tokens_count_ptr, # [num_reqs]
|
||||
query_start_loc_gpu_ptr, # [num_reqs + 1]
|
||||
token_indices_to_sample_ptr, # [num_reqs] (output)
|
||||
num_rejected_tokens_gpu_ptr,
|
||||
num_reqs, # tl.int32
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
@@ -61,3 +62,4 @@ def prepare_inputs_padded_kernel(
|
||||
|
||||
index_to_sample = q_last_tok_idx - num_rejected
|
||||
tl.store(token_indices_to_sample_ptr + offsets, index_to_sample, mask=mask)
|
||||
tl.store(num_rejected_tokens_gpu_ptr + offsets, num_rejected, mask=mask)
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||
#
|
||||
|
||||
from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer
|
||||
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
|
||||
from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
|
||||
from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
|
||||
@@ -31,5 +33,7 @@ def get_spec_decode_method(method, vllm_config, device, runner):
|
||||
return AscendMedusaProposer(vllm_config, device)
|
||||
elif method in ("eagle", "eagle3", "mtp"):
|
||||
return AscendEagleProposer(vllm_config, device, runner)
|
||||
elif method == "draft_model":
|
||||
return AscendDraftModelProposer(vllm_config, device, runner)
|
||||
else:
|
||||
raise ValueError(f"Unknown speculative decoding method: {method}")
|
||||
|
||||
71
vllm_ascend/spec_decode/draft_proposer.py
Normal file
71
vllm_ascend/spec_decode/draft_proposer.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing_extensions import override
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
|
||||
|
||||
from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AscendDraftModelProposer(SpecDecodeBaseProposer):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
runner=None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
device=device,
|
||||
pass_hidden_states_to_model=False,
|
||||
runner=runner,
|
||||
)
|
||||
self._raise_if_vocab_size_mismatch()
|
||||
self._raise_if_draft_tp_mismatch()
|
||||
|
||||
def _raise_if_vocab_size_mismatch(self):
|
||||
self.speculative_config.verify_equal_vocab_size_if_draft_model()
|
||||
|
||||
def _raise_if_draft_tp_mismatch(self):
|
||||
# Note(Tomas Ruiz) If we run the target model with TP > 1 and
|
||||
# the draft model with TP = 1, then the different TP ranks collide.
|
||||
# Specifically when all ranks compile the draft model on rank 0
|
||||
# (because TP=1), then the torch compile cache is overwritten and corrupted.
|
||||
# We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
|
||||
# To prevent this error, we assert that both TP sizes must be the same.
|
||||
spec_cfg = self.speculative_config
|
||||
tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
|
||||
draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
|
||||
if draft_tp != tgt_tp:
|
||||
raise ValueError(
|
||||
f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
|
||||
f"must be the same. Got {draft_tp} and {tgt_tp}. "
|
||||
"Please pass 'draft_tensor_parallel_size' in the speculative_config."
|
||||
)
|
||||
|
||||
def _get_model(self) -> nn.Module:
|
||||
# Draft models may be quantized or on different parallelism,
|
||||
# so we load them with a modified vllm config
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config)
|
||||
with set_model_tag("draft_model"):
|
||||
model = get_model(
|
||||
vllm_config=temp_vllm_config,
|
||||
prefix="draft_model",
|
||||
)
|
||||
return model
|
||||
|
||||
@override
|
||||
def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
|
||||
# Draft models don't share embeddings with the target model
|
||||
pass
|
||||
|
||||
@override
|
||||
def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
|
||||
# Draft models don't share lm_head with the target model
|
||||
pass
|
||||
@@ -30,8 +30,13 @@ from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.utils import (
|
||||
PADDING_SLOT_ID,
|
||||
compute_new_slot_mapping,
|
||||
extend_all_queries_by_N,
|
||||
)
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, set_ascend_forward_context
|
||||
@@ -80,14 +85,14 @@ def split_inputs_tp_to_sp(hidden_states, out):
|
||||
return out[:padded_num_tokens_per_rank]
|
||||
|
||||
|
||||
class AscendEagleProposer(EagleProposer):
|
||||
class SpecDecodeBaseProposer(EagleProposer):
|
||||
_runnable: ACLGraphWrapper | Callable
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device, pass_hidden_states_to_model: bool, runner=None):
|
||||
super().__init__(vllm_config, device, runner)
|
||||
|
||||
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
|
||||
|
||||
self.pass_hidden_states_to_model = pass_hidden_states_to_model
|
||||
self.decode_threshold = 1 + self.num_speculative_tokens
|
||||
self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 2, dtype=torch.int32)
|
||||
self.arange_cpu = torch.arange(self.arange.shape[0], device="cpu", dtype=torch.int32)
|
||||
@@ -140,7 +145,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
if not self.use_cuda_graph and enable_sp(vllm_config):
|
||||
self.maybe_eager_context = _maybe_eager_context(vllm_config)
|
||||
|
||||
self.last_token_indices = torch.zeros(
|
||||
self.token_indices_to_sample = torch.zeros(
|
||||
self.vllm_config.scheduler_config.max_num_batched_tokens, dtype=torch.int32, device=device
|
||||
)
|
||||
slot_mapping_lens = self.runner.max_num_tokens + 2 * self.pcp_size * self.runner.max_num_reqs
|
||||
@@ -150,15 +155,38 @@ class AscendEagleProposer(EagleProposer):
|
||||
]
|
||||
|
||||
self._runnable = self._run_merged_draft
|
||||
if self.uses_mrope:
|
||||
self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), dtype=torch.int32, device=device)
|
||||
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
|
||||
self.xdrope_positions = torch.zeros(
|
||||
(self.uses_xdrope_dim, self.max_num_tokens + 1),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
# RoPE need (max_num_tokens,)
|
||||
self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device)
|
||||
|
||||
def _get_model(self) -> nn.Module:
|
||||
"""
|
||||
Default method to call get_model(). Can be overridden by subclasses which
|
||||
need to customize model loading.
|
||||
"""
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
with set_model_tag("eagle_head"):
|
||||
model = get_model(
|
||||
vllm_config=self.vllm_config,
|
||||
model_config=self.vllm_config.speculative_config.draft_model_config,
|
||||
)
|
||||
return model
|
||||
|
||||
def load_model(self, model: nn.Module) -> None:
|
||||
target_attn_layer_names = set(get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys())
|
||||
target_indexer_layer_names = set(get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys())
|
||||
|
||||
with self.maybe_eager_context:
|
||||
self.model = get_model(
|
||||
vllm_config=self.vllm_config, model_config=self.vllm_config.speculative_config.draft_model_config
|
||||
)
|
||||
self.model = self._get_model()
|
||||
|
||||
indexer_layers = get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache).keys()
|
||||
draft_attn_layers_dict = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
|
||||
@@ -167,7 +195,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
draft_attn_layer_names = draft_attn_layers - target_attn_layer_names
|
||||
draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
|
||||
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
|
||||
assert len(draft_attn_layer_names) == 1
|
||||
|
||||
self.attn_layer_names = list(sorted(draft_attn_layer_names))
|
||||
|
||||
self.kernel_block_size = (
|
||||
@@ -202,6 +230,24 @@ class AscendEagleProposer(EagleProposer):
|
||||
target_language_model = model
|
||||
|
||||
# share embed_tokens with the target model if needed
|
||||
self._maybe_share_embeddings(target_language_model)
|
||||
self._maybe_share_lm_head(model)
|
||||
|
||||
if self.parallel_drafting and self.pass_hidden_states_to_model:
|
||||
assert self.parallel_drafting_hidden_state_tensor is not None
|
||||
self.parallel_drafting_hidden_state_tensor.copy_(
|
||||
self.model.combine_hidden_states(self.model.mask_hidden.view(3 * self.hidden_size))
|
||||
if self.eagle3_use_aux_hidden_state
|
||||
else self.model.mask_hidden.view(self.hidden_size)
|
||||
)
|
||||
|
||||
def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
|
||||
"""
|
||||
Some draft models may not have their own embedding layers, and some may
|
||||
have a duplicate copy of the target model's embedding layers. In these cases,
|
||||
we share the target model's embedding layers with the draft model to save
|
||||
memory.
|
||||
"""
|
||||
if get_pp_group().world_size == 1:
|
||||
if hasattr(target_language_model.model, "embed_tokens"):
|
||||
target_embed_tokens = target_language_model.model.embed_tokens
|
||||
@@ -256,7 +302,9 @@ class AscendEagleProposer(EagleProposer):
|
||||
"Since PP > 1 or other reasons the model head loaded its own vocab embedding"
|
||||
" weights instead of sharing them with the target model."
|
||||
)
|
||||
# share lm_head with the target model if needed
|
||||
|
||||
# share lm_head with the target model if needed
|
||||
def _maybe_share_lm_head(self, model: nn.Module) -> None:
|
||||
# some model definition do not define lm_head explicitly
|
||||
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
|
||||
if self.method == "eagle" and hasattr(model, "lm_head"):
|
||||
@@ -389,7 +437,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
self._runnable(
|
||||
num_input_tokens=num_tokens,
|
||||
batch_size=batch_size,
|
||||
last_token_indices=self.last_token_indices[:batch_size],
|
||||
token_indices_to_sample=self.token_indices_to_sample[: batch_size * self.extra_slots_per_request],
|
||||
# The target_position's address is same as the model_positions's
|
||||
target_positions=model_positions,
|
||||
inputs_embeds=None,
|
||||
@@ -411,7 +459,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
target_hidden_states: torch.Tensor,
|
||||
# [batch_size]
|
||||
next_token_ids: torch.Tensor,
|
||||
last_token_indices: torch.Tensor | None,
|
||||
token_indices_to_sample: torch.Tensor | None,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
|
||||
@@ -421,31 +469,34 @@ class AscendEagleProposer(EagleProposer):
|
||||
num_decode_reqs=0,
|
||||
scheduler_output: SchedulerOutput = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
num_rejected_tokens_gpu: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
batch_size = common_attn_metadata.batch_size()
|
||||
|
||||
if last_token_indices is None:
|
||||
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||
if token_indices_to_sample is None:
|
||||
token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1
|
||||
|
||||
if self.method == "eagle3":
|
||||
assert isinstance(self.get_model(), Eagle3LlamaForCausalLM)
|
||||
target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
|
||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
num_tokens, token_indices_to_sample, common_attn_metadata = self.set_inputs_first_pass(
|
||||
target_token_ids=target_token_ids,
|
||||
next_token_ids=next_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
token_indices_to_sample=token_indices_to_sample,
|
||||
cad=common_attn_metadata,
|
||||
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
|
||||
)
|
||||
|
||||
assert self.runner is not None
|
||||
# update pcp related params
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
assert long_seq_metadata is not None
|
||||
common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
|
||||
ori_last_token_indices = last_token_indices.clone()
|
||||
ori_token_indices_to_sample = token_indices_to_sample.clone()
|
||||
query_lens_d = self.runner.query_lens[:num_decode_reqs]
|
||||
if self.pcp_size > 1:
|
||||
# 1. preprocess decode/prefill input_ids & target_hidden_states
|
||||
@@ -484,9 +535,11 @@ class AscendEagleProposer(EagleProposer):
|
||||
target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
|
||||
# 2. update sample_indices according to main model
|
||||
if num_decode_reqs:
|
||||
last_token_indices[:num_decode_reqs] = self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
|
||||
token_indices_to_sample[:num_decode_reqs] = self.runner.logits_indices[
|
||||
token_indices_to_sample[:num_decode_reqs]
|
||||
]
|
||||
if num_prefill_reqs:
|
||||
last_token_indices[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
|
||||
token_indices_to_sample[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
|
||||
# 3. update attn_metadata params that may be influenced by pcp
|
||||
common_attn_metadata.num_actual_tokens = num_tokens
|
||||
common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p)
|
||||
@@ -530,10 +583,6 @@ class AscendEagleProposer(EagleProposer):
|
||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||
batch_descriptor = None
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self._set_positions(num_tokens, target_positions)
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
|
||||
inputs_embeds = self.model.embed_input_ids(
|
||||
@@ -559,15 +608,16 @@ class AscendEagleProposer(EagleProposer):
|
||||
attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())
|
||||
|
||||
if self.uses_mrope:
|
||||
used_update_positions = target_positions[:, last_token_indices]
|
||||
used_update_positions = self.mrope_positions[:, token_indices_to_sample]
|
||||
else:
|
||||
used_update_positions = target_positions[last_token_indices]
|
||||
used_update_positions = self.positions[token_indices_to_sample]
|
||||
per_layer_attn_metadata = dict()
|
||||
# The first step of speculative.
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
multi_steps_attn_metadata = [per_layer_attn_metadata]
|
||||
|
||||
# Copy the old attn_metadata and update
|
||||
attn_metadata_i = per_layer_attn_metadata[self.attn_layer_names[0]]
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
if self.num_speculative_tokens > 1 and not attn_metadata_i.num_prefills:
|
||||
@@ -578,7 +628,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
# to get corresponding slot_mapping in each step.
|
||||
num_reject_tokens = (
|
||||
torch.tensor(self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to(self.device)
|
||||
- ori_last_token_indices
|
||||
- ori_token_indices_to_sample
|
||||
- 1
|
||||
)
|
||||
num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens
|
||||
@@ -616,6 +666,27 @@ class AscendEagleProposer(EagleProposer):
|
||||
common_attn_metadata.block_table_tensor = common_attn_metadata.block_table_tensor[:batch_size]
|
||||
|
||||
# Copy the old attn_metadata and update
|
||||
if not self.parallel_drafting:
|
||||
for draft_step in range(1, self.num_speculative_tokens):
|
||||
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
|
||||
draft_step,
|
||||
attn_metadata,
|
||||
common_attn_metadata,
|
||||
batch_size,
|
||||
num_input_tokens,
|
||||
used_update_positions,
|
||||
aclgraph_runtime_mode,
|
||||
ori_seq_len,
|
||||
slot_indices,
|
||||
mtp_slot_mapping,
|
||||
)
|
||||
per_layer_attn_metadata = dict()
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
multi_steps_attn_metadata.append(per_layer_attn_metadata)
|
||||
else:
|
||||
# Copy the old attn_metadata and update
|
||||
if not self.parallel_drafting:
|
||||
for draft_step in range(1, self.num_speculative_tokens):
|
||||
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
|
||||
draft_step,
|
||||
@@ -625,33 +696,14 @@ class AscendEagleProposer(EagleProposer):
|
||||
num_input_tokens,
|
||||
used_update_positions,
|
||||
aclgraph_runtime_mode,
|
||||
ori_seq_len,
|
||||
slot_indices,
|
||||
mtp_slot_mapping,
|
||||
)
|
||||
per_layer_attn_metadata = dict()
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
multi_steps_attn_metadata.append(per_layer_attn_metadata)
|
||||
else:
|
||||
# Copy the old attn_metadata and update
|
||||
for draft_step in range(1, self.num_speculative_tokens):
|
||||
common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm(
|
||||
draft_step,
|
||||
attn_metadata,
|
||||
common_attn_metadata,
|
||||
batch_size,
|
||||
num_input_tokens,
|
||||
used_update_positions,
|
||||
aclgraph_runtime_mode,
|
||||
)
|
||||
per_layer_attn_metadata = dict()
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
multi_steps_attn_metadata.append(per_layer_attn_metadata)
|
||||
|
||||
last_token_indices_len = last_token_indices.shape[0]
|
||||
self.last_token_indices[:last_token_indices_len].copy_(last_token_indices)
|
||||
token_indices_to_sample_len = token_indices_to_sample.shape[0]
|
||||
self.token_indices_to_sample[:token_indices_to_sample_len].copy_(token_indices_to_sample)
|
||||
|
||||
with set_ascend_forward_context(
|
||||
multi_steps_attn_metadata[0],
|
||||
@@ -672,7 +724,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
draft_token_ids = self._runnable(
|
||||
num_input_tokens=num_input_tokens,
|
||||
batch_size=batch_size,
|
||||
last_token_indices=self.last_token_indices[:last_token_indices_len],
|
||||
token_indices_to_sample=self.token_indices_to_sample[:token_indices_to_sample_len],
|
||||
target_positions=target_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multi_steps_attn_metadata=multi_steps_attn_metadata,
|
||||
@@ -689,7 +741,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
self,
|
||||
num_input_tokens,
|
||||
batch_size,
|
||||
last_token_indices,
|
||||
token_indices_to_sample,
|
||||
target_positions,
|
||||
inputs_embeds,
|
||||
multi_steps_attn_metadata,
|
||||
@@ -702,17 +754,22 @@ class AscendEagleProposer(EagleProposer):
|
||||
# `model_hidden_states` represent the speculative model inputs.
|
||||
model_input_ids = self.input_ids[:num_input_tokens]
|
||||
model_positions = self._get_positions(num_input_tokens)
|
||||
model_hidden_states = self.hidden_states[:num_input_tokens]
|
||||
|
||||
model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions)
|
||||
model_kwargs = {
|
||||
"input_ids": model_input_ids,
|
||||
"positions": model_positions,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=model_input_ids,
|
||||
positions=model_positions,
|
||||
hidden_states=model_hidden_states,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method == "mtp":
|
||||
if self.pass_hidden_states_to_model:
|
||||
model_hidden_states = self.hidden_states[:num_input_tokens]
|
||||
model_hidden_states, model_positions = self.maybe_pad_and_reduce(model_hidden_states, model_positions)
|
||||
model_kwargs["hidden_states"] = model_hidden_states
|
||||
if self.method == "mtp":
|
||||
model_kwargs["positions"] = model_positions
|
||||
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = last_hidden_states
|
||||
else:
|
||||
@@ -722,6 +779,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
last_hidden_states, model_positions, hidden_states
|
||||
)
|
||||
|
||||
num_indices = token_indices_to_sample.shape[0]
|
||||
if self.pcp_size > 1:
|
||||
# remove graph padding before all_gather
|
||||
hidden_states = hidden_states[:num_tokens]
|
||||
@@ -741,26 +799,27 @@ class AscendEagleProposer(EagleProposer):
|
||||
self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: last_hidden_states.shape[0]],
|
||||
)
|
||||
|
||||
num_indices = last_token_indices.shape[0]
|
||||
if lmhead_tp_enable() and not is_dummy:
|
||||
max_num_reqs_across_dp = (
|
||||
self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
|
||||
)
|
||||
last_token_indices = nn.functional.pad(last_token_indices, (0, max_num_reqs_across_dp - num_indices))
|
||||
token_indices_to_sample = nn.functional.pad(
|
||||
token_indices_to_sample, (0, max_num_reqs_across_dp - num_indices)
|
||||
)
|
||||
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
|
||||
logits = logits[:num_indices]
|
||||
last_token_indices = last_token_indices[:num_indices]
|
||||
token_indices_to_sample = token_indices_to_sample[:num_indices]
|
||||
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.num_speculative_tokens == 1:
|
||||
if self.num_speculative_tokens == 1 or self.parallel_drafting:
|
||||
# [batch_size, 1]
|
||||
return draft_token_ids.view(-1, 1)
|
||||
return draft_token_ids.view(-1, self.num_speculative_tokens)
|
||||
|
||||
if self.pcp_size * self.dcp_size > 1 and is_prefill:
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
@@ -775,11 +834,11 @@ class AscendEagleProposer(EagleProposer):
|
||||
)
|
||||
draft_token_ids_tensor[0] = draft_token_ids
|
||||
if self.uses_mrope:
|
||||
positions = target_positions[:, last_token_indices]
|
||||
positions = self.mrope_positions[:, token_indices_to_sample]
|
||||
else:
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
last_token_indices = self.arange[:batch_size]
|
||||
positions = self.positions[token_indices_to_sample]
|
||||
hidden_states = hidden_states[token_indices_to_sample]
|
||||
token_indices_to_sample = self.arange[:batch_size]
|
||||
|
||||
input_batch_size = num_input_tokens if (self.method == "mtp" or self.use_cuda_graph) else batch_size
|
||||
|
||||
@@ -843,13 +902,17 @@ class AscendEagleProposer(EagleProposer):
|
||||
forward_context.attn_metadata = (
|
||||
multi_steps_attn_metadata[draft_step + 1] if multi_steps_attn_metadata else None
|
||||
)
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=model_input_ids,
|
||||
positions=model_positions,
|
||||
hidden_states=model_hidden_states,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method == "mtp":
|
||||
|
||||
model_kwargs = {
|
||||
"input_ids": model_input_ids,
|
||||
"positions": model_positions,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.pass_hidden_states_to_model:
|
||||
model_kwargs["hidden_states"] = model_hidden_states
|
||||
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = last_hidden_states
|
||||
else:
|
||||
@@ -859,22 +922,22 @@ class AscendEagleProposer(EagleProposer):
|
||||
last_hidden_states, model_positions, hidden_states
|
||||
)
|
||||
|
||||
num_indices = last_token_indices.shape[0]
|
||||
num_indices = token_indices_to_sample.shape[0]
|
||||
if lmhead_tp_enable() and not is_dummy:
|
||||
max_num_reqs_across_dp = (
|
||||
self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
|
||||
)
|
||||
last_token_indices = nn.functional.pad(
|
||||
last_token_indices,
|
||||
token_indices_to_sample = nn.functional.pad(
|
||||
token_indices_to_sample,
|
||||
(0, max_num_reqs_across_dp - num_indices),
|
||||
)
|
||||
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
|
||||
logits = logits[:num_indices]
|
||||
last_token_indices = last_token_indices[:num_indices]
|
||||
token_indices_to_sample = token_indices_to_sample[:num_indices]
|
||||
|
||||
# TODO(wenlong): get more than one token for tree attention
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
@@ -885,6 +948,122 @@ class AscendEagleProposer(EagleProposer):
|
||||
draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
|
||||
return draft_token_ids
|
||||
|
||||
def set_inputs_first_pass(
|
||||
self,
|
||||
target_token_ids: torch.Tensor,
|
||||
next_token_ids: torch.Tensor,
|
||||
target_positions: torch.Tensor,
|
||||
target_hidden_states: torch.Tensor,
|
||||
token_indices_to_sample: torch.Tensor | None,
|
||||
cad: CommonAttentionMetadata,
|
||||
num_rejected_tokens_gpu: torch.Tensor | None,
|
||||
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
|
||||
if not self.needs_extra_input_slots:
|
||||
# Default EAGLE pathway: no reshaping of input tensors needed.
|
||||
# Simply rotate the input ids and leave the positions unchanged,
|
||||
# Inserting the next token ids at the last slot in each request.
|
||||
if token_indices_to_sample is None:
|
||||
token_indices_to_sample = cad.query_start_loc[1:] - 1
|
||||
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
self.input_ids[token_indices_to_sample] = next_token_ids
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
|
||||
target_positions = target_positions[0]
|
||||
|
||||
self._set_positions(num_tokens, target_positions)
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
|
||||
return num_tokens, token_indices_to_sample, cad
|
||||
else:
|
||||
assert self.is_rejected_token_mask is not None
|
||||
assert self.is_masked_token_mask is not None
|
||||
# 1.
|
||||
# Call the CopyAndExpandEagleInputs AscendC operator to copy
|
||||
# input_ids and positions into the correct slots in the
|
||||
# preallocated buffers self.input_ids, self.positions.
|
||||
batch_size = cad.batch_size()
|
||||
total_num_input_tokens = target_token_ids.shape[0]
|
||||
total_num_output_tokens = total_num_input_tokens + (self.net_num_new_slots_per_request * batch_size)
|
||||
|
||||
query_start_loc = cad.query_start_loc
|
||||
query_end_loc = cad.query_start_loc[1:] - 1
|
||||
if num_rejected_tokens_gpu is not None:
|
||||
query_end_loc = query_end_loc - num_rejected_tokens_gpu
|
||||
|
||||
(
|
||||
out_input_ids,
|
||||
out_positions,
|
||||
out_is_rejected_token_mask,
|
||||
out_is_masked_token_mask,
|
||||
token_indices_to_sample,
|
||||
out_hidden_state_mapping,
|
||||
) = torch.ops._C_ascend.npu_copy_and_expand_eagle_inputs(
|
||||
target_token_ids,
|
||||
target_positions.to(torch.int32),
|
||||
next_token_ids,
|
||||
query_start_loc,
|
||||
query_end_loc,
|
||||
0, # padding_token_id
|
||||
self.parallel_drafting_token_id,
|
||||
self.extra_slots_per_request,
|
||||
self.pass_hidden_states_to_model,
|
||||
total_num_output_tokens,
|
||||
)
|
||||
|
||||
# Copy returned tensors into pre-allocated buffers
|
||||
self.input_ids[:total_num_output_tokens].copy_(out_input_ids)
|
||||
self.positions[:total_num_output_tokens].copy_(out_positions)
|
||||
self.is_rejected_token_mask[:total_num_output_tokens].copy_(out_is_rejected_token_mask)
|
||||
self.is_masked_token_mask[:total_num_output_tokens].copy_(out_is_masked_token_mask)
|
||||
if self.pass_hidden_states_to_model:
|
||||
assert self.parallel_drafting_hidden_state_tensor is not None
|
||||
self.hidden_states[out_hidden_state_mapping] = target_hidden_states
|
||||
# Use torch.where to avoid DtoH sync from boolean indexing
|
||||
mask = self.is_masked_token_mask[:total_num_output_tokens]
|
||||
torch.where(
|
||||
mask.unsqueeze(1),
|
||||
self.parallel_drafting_hidden_state_tensor,
|
||||
self.hidden_states[:total_num_output_tokens],
|
||||
out=self.hidden_states[:total_num_output_tokens],
|
||||
)
|
||||
|
||||
# 2.
|
||||
# Recompute the slot mapping based on the new positions and
|
||||
# rejection mask.
|
||||
builder = (
|
||||
self._get_attention_metadata_builder()
|
||||
if self.attn_metadata_builder is None
|
||||
else self.attn_metadata_builder
|
||||
)
|
||||
new_slot_mapping = compute_new_slot_mapping(
|
||||
cad=cad,
|
||||
new_positions=self.positions[:total_num_output_tokens],
|
||||
is_rejected_token_mask=self.is_rejected_token_mask[:total_num_output_tokens],
|
||||
block_size=builder.kv_cache_spec.block_size,
|
||||
num_new_tokens=self.net_num_new_slots_per_request,
|
||||
max_model_len=self.max_model_len,
|
||||
)
|
||||
|
||||
# 3. Update the common attention metadata with the new (meta)data
|
||||
new_cad = extend_all_queries_by_N(
|
||||
cad,
|
||||
N=self.net_num_new_slots_per_request,
|
||||
arange=self.arange,
|
||||
new_slot_mapping=new_slot_mapping,
|
||||
)
|
||||
|
||||
return total_num_output_tokens, token_indices_to_sample, new_cad
|
||||
|
||||
def model_returns_tuple(self) -> bool:
|
||||
return self.method not in ("mtp", "draft_model")
|
||||
|
||||
def attn_update_stack_num_spec_norm(
|
||||
self,
|
||||
# `draft_step` must start from `1`, no `0`
|
||||
@@ -1201,7 +1380,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
spec_decode_metadata: SpecDecodeMetadata,
|
||||
valid_sampled_tokens_count: torch.Tensor,
|
||||
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function is used to prepare the inputs for speculative decoding
|
||||
It updates the common_attn_metadata for speculative decoding,
|
||||
@@ -1215,7 +1394,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
device = valid_sampled_tokens_count.device
|
||||
|
||||
token_indices_to_sample = torch.empty((num_reqs,), dtype=torch.int32, device=device)
|
||||
|
||||
num_rejected_tokens_gpu = torch.empty((num_reqs,), dtype=torch.int32, device=device)
|
||||
num_blocks_needed = triton.cdiv(num_reqs, _PREPARE_INPUTS_BLOCK_SIZE)
|
||||
num_vector_core = get_vectorcore_num()
|
||||
grid_size = min(num_blocks_needed, num_vector_core)
|
||||
@@ -1226,6 +1405,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
valid_sampled_tokens_count,
|
||||
common_attn_metadata.query_start_loc,
|
||||
token_indices_to_sample,
|
||||
num_rejected_tokens_gpu,
|
||||
num_reqs,
|
||||
BLOCK_SIZE=_PREPARE_INPUTS_BLOCK_SIZE,
|
||||
)
|
||||
@@ -1274,7 +1454,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
max_seq_len=0,
|
||||
)
|
||||
|
||||
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
||||
return spec_common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu
|
||||
|
||||
def _split_pcp_input(self, req_scheduled_tokens, input_ids, target_hidden_states):
|
||||
"""
|
||||
@@ -1394,3 +1574,18 @@ class AscendEagleProposer(EagleProposer):
|
||||
if hidden_states is not None:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), True)
|
||||
return last_hidden_states, positions, hidden_states
|
||||
|
||||
|
||||
class AscendEagleProposer(SpecDecodeBaseProposer):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
runner=None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config,
|
||||
device,
|
||||
pass_hidden_states_to_model=True,
|
||||
runner=runner,
|
||||
)
|
||||
|
||||
@@ -108,6 +108,7 @@ from vllm_ascend.patch.worker.patch_draft_quarot import patch_load_weights
|
||||
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
||||
from vllm_ascend.sample.sampler import AscendSampler
|
||||
from vllm_ascend.spec_decode import get_spec_decode_method
|
||||
from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer
|
||||
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
|
||||
from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
|
||||
from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
|
||||
@@ -406,7 +407,12 @@ class NPUModelRunner(GPUModelRunner):
|
||||
def _set_up_drafter(self):
|
||||
# Set up speculative decoding.
|
||||
self.drafter: (
|
||||
AscendNgramProposer | AscendEagleProposer | AscendSuffixDecodingProposer | AscendMedusaProposer | None
|
||||
AscendNgramProposer
|
||||
| AscendEagleProposer
|
||||
| AscendDraftModelProposer
|
||||
| AscendSuffixDecodingProposer
|
||||
| AscendMedusaProposer
|
||||
| None
|
||||
) = None
|
||||
self.actual_seq_lengths_q: list[int] = []
|
||||
self.decode_token_per_req = 1
|
||||
@@ -971,7 +977,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
draft_token_ids = self.drafter.propose(
|
||||
valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states
|
||||
)
|
||||
elif self.speculative_config.use_eagle():
|
||||
elif self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model():
|
||||
common_attn_metadata = spec_decode_common_attn_metadata
|
||||
sampled_token_ids = valid_sampled_token_ids
|
||||
|
||||
@@ -1018,6 +1024,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
long_seq_metadata = None # type: ignore
|
||||
num_prefill_reqs = 0
|
||||
num_decode_reqs = 0
|
||||
|
||||
num_rejected_tokens_gpu = None
|
||||
if spec_decode_metadata is None:
|
||||
# update pcp related params
|
||||
if self.pcp_size > 1:
|
||||
@@ -1053,8 +1061,10 @@ class NPUModelRunner(GPUModelRunner):
|
||||
)
|
||||
else:
|
||||
assert self.drafter is not None
|
||||
common_attn_metadata, token_indices, token_indices_to_sample = self.drafter.prepare_inputs_padded(
|
||||
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
|
||||
common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu = (
|
||||
self.drafter.prepare_inputs_padded(
|
||||
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
|
||||
)
|
||||
)
|
||||
if self.pcp_size > 1:
|
||||
target_token_ids = input_ids_pcp_full[token_indices]
|
||||
@@ -1075,7 +1085,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
next_token_ids=next_token_ids,
|
||||
last_token_indices=token_indices_to_sample,
|
||||
token_indices_to_sample=token_indices_to_sample,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata,
|
||||
req_scheduled_tokens=req_scheduled_tokens,
|
||||
@@ -1084,6 +1094,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
num_decode_reqs=num_decode_reqs,
|
||||
scheduler_output=scheduler_output,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")
|
||||
@@ -1516,16 +1527,16 @@ class NPUModelRunner(GPUModelRunner):
|
||||
|
||||
with record_function_or_nullcontext("draft_token"):
|
||||
if self.speculative_config:
|
||||
use_padded_batch_for_eagle = (
|
||||
use_padded_batch = (
|
||||
self.speculative_config
|
||||
and self.speculative_config.use_eagle()
|
||||
and (self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model())
|
||||
and not self.speculative_config.disable_padded_drafter_batch
|
||||
)
|
||||
if use_padded_batch_for_eagle:
|
||||
if use_padded_batch:
|
||||
# EAGLE speculative decoding can use the GPU sampled tokens
|
||||
# as inputs, and does not need to wait for bookkeeping to finish.
|
||||
propose_draft_token_ids(sampler_output.sampled_token_ids)
|
||||
if self.speculative_config and not use_padded_batch_for_eagle:
|
||||
if self.speculative_config and not use_padded_batch:
|
||||
# ngram and other speculative decoding methods use the sampled
|
||||
# tokens on the CPU, so they are run after bookkeeping.
|
||||
propose_draft_token_ids(valid_sampled_token_ids)
|
||||
@@ -2165,7 +2176,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
if kv_cache_gid > 0:
|
||||
cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid)
|
||||
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
||||
if isinstance(self.drafter, AscendEagleProposer):
|
||||
if isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer):
|
||||
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
|
||||
spec_decode_common_attn_metadata = cm
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user