update

vllm/v1/spec_decode/__init__.py (new file, 0 lines)

vllm/v1/spec_decode/draft_model.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn as nn
from typing_extensions import override

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model

logger = init_logger(__name__)


class DraftModelProposer(SpecDecodeBaseProposer):
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=False,
            runner=runner,
        )
        self._raise_if_vocab_size_mismatch()
        self._raise_if_draft_tp_mismatch()

    def _raise_if_vocab_size_mismatch(self):
        self.speculative_config.verify_equal_vocab_size_if_draft_model()

    def _raise_if_draft_tp_mismatch(self):
        # Note(Tomas Ruiz): If we run the target model with TP > 1 and
        # the draft model with TP = 1, the different TP ranks collide.
        # Specifically, when all ranks compile the draft model on rank 0
        # (because TP=1), the torch.compile cache is overwritten and corrupted.
        # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
        # To prevent this error, we assert that both TP sizes must be the same.
        spec_cfg = self.speculative_config
        tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
        draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
        if draft_tp != tgt_tp:
            raise ValueError(
                f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
                f"must be the same. Got {draft_tp} and {tgt_tp}. "
                "Please pass 'draft_tensor_parallel_size' in the speculative_config."
            )

    @override
    def _get_model(self) -> nn.Module:
        # Draft models may be quantized or use different parallelism,
        # so we load them with a modified vllm config.
        from vllm.compilation.backends import set_model_tag

        temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config)
        with set_model_tag("draft_model"):
            model = get_model(
                vllm_config=temp_vllm_config,
                prefix="draft_model",
            )
        return model

    @override
    def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
        # Draft models don't share embeddings with the target model.
        pass

    @override
    def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
        # Draft models don't share lm_head with the target model.
        pass
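For context, a hedged usage sketch (not part of this commit) of launching vLLM with a separate draft model. The speculative_config keys follow the error message above and common vLLM usage; the model names and the "method" value are assumptions for illustration only.

# Hypothetical launch sketch; per the check above, the draft TP size must
# equal the target TP size.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",  # assumed target model
    tensor_parallel_size=4,
    speculative_config={
        "model": "meta-llama/Llama-3.2-1B-Instruct",  # assumed draft model
        "method": "draft_model",                      # assumed method label
        "num_speculative_tokens": 4,
        "draft_tensor_parallel_size": 4,              # must equal target TP
    },
)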
vllm/v1/spec_decode/eagle.py (new file, 1765 lines)
File diff suppressed because it is too large.
vllm/v1/spec_decode/medusa.py (new file, 78 lines)
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.v1.sample.metadata import SamplingMetadata

# Initialize logger
logger = init_logger(__name__)


class MedusaProposer:
    """
    Medusa proposer class for generating token sequences
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        # Save config parameters
        self.vllm_config = vllm_config
        assert vllm_config.speculative_config is not None, (
            "Speculative config must be set"
        )
        self.spec_config = vllm_config.speculative_config
        self.device = device
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.hidden_size = self.spec_config.draft_model_config.get_hidden_size()
        self.dtype = vllm_config.model_config.dtype

    def propose(
        self,
        target_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,  # unused
    ) -> torch.Tensor:
        # Generate blocks and compute logits
        blocks = self.model(target_hidden_states)
        logits = self.model.compute_logits(blocks)

        # Compute argmax for each Medusa head and stack into a single tensor
        # Shape: [batch_size, num_heads]
        draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)

        return draft_tokens

    def load_model(self, target_model: nn.Module) -> None:
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("medusa_head"):
            self.model = get_model(
                vllm_config=self.vllm_config,
                model_config=self.spec_config.draft_model_config,
            )
        assert not (
            is_mixture_of_experts(self.model)
            and self.vllm_config.parallel_config.enable_eplb
        ), "EPLB for Medusa is not supported"

    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device,
        )
        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
            self.model(hidden_states)
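A minimal shape sketch (not part of this commit) of the greedy stacking in propose(), using random tensors instead of a real Medusa head; the sizes are assumed toy values.

import torch

batch_size, num_heads, vocab_size = 2, 3, 11  # assumed toy sizes
# One logits tensor per Medusa head, each of shape [batch_size, vocab_size].
logits = [torch.randn(batch_size, vocab_size) for _ in range(num_heads)]
# Greedy pick per head, stacked along dim=1 -> [batch_size, num_heads].
draft_tokens = torch.stack([l.argmax(dim=-1) for l in logits], dim=1)
print(draft_tokens.shape)  # torch.Size([2, 3])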
vllm/v1/spec_decode/metadata.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class SpecDecodeMetadata:
    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
    # [batch_size]
    cu_num_sampled_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self):
        self.max_spec_len = max(self.num_draft_tokens)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
        num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids]
        flattened_draft_token_ids = sum(draft_token_ids, [])
        num_tokens = len(flattened_draft_token_ids)

        draft_token_ids_tensor = torch.tensor(
            flattened_draft_token_ids, dtype=torch.int32, device=device
        )
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
            device
        )

        target_logits_indices = torch.zeros(
            num_tokens, dtype=torch.int32, device=device
        )
        bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device)
        logits_indices = torch.zeros(
            num_tokens + batch_size, dtype=torch.int32, device=device
        )
        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
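For reference, a small worked example (not part of this commit) of the cumulative-sum bookkeeping that make_dummy computes for its flattened buffers.

# Two requests with 2 and 3 draft tokens respectively:
draft_token_ids = [[7, 8], [4, 5, 6]]
num_draft_tokens = [2, 3]             # per-request draft counts
num_sampled_tokens = [3, 4]           # one extra (bonus) token per request
# Inclusive cumulative sums used to index into the flattened buffers:
#   cu_num_draft_tokens   = [2, 5]
#   cu_num_sampled_tokens = [3, 7]
# The draft_token_ids tensor holds the flattened list [7, 8, 4, 5, 6].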
vllm/v1/spec_decode/metrics.py (new file, 225 lines)
@@ -0,0 +1,225 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from dataclasses import dataclass, field

import numpy as np
import prometheus_client

from vllm.config import SpeculativeConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclass
class SpecDecodingStats:
    """Per-iteration spec decoding stats from the scheduler.

    Each scheduler step, statistics on spec decoding performance are
    aggregated across requests by the scheduler and returned to the
    frontend in EngineCoreOutputs->SchedulerStats.
    """

    num_spec_tokens: int
    num_drafts: int = 0
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0
    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)

    @classmethod
    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
        return cls(
            num_spec_tokens=num_spec_tokens,
            num_accepted_tokens_per_pos=[0] * num_spec_tokens,
        )

    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
        self.num_drafts += 1
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
        assert num_accepted_tokens <= self.num_spec_tokens
        for i in range(num_accepted_tokens):
            self.num_accepted_tokens_per_pos[i] += 1


class SpecDecodingLogging:
    """Aggregate and log spec decoding metrics.

    LoggingStatLogger aggregates per-iteration metrics over a set
    time interval using observe() and then logs them using log()
    before resetting to zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.num_drafts: list[int] = []
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []
        self.accepted_tokens_per_pos_lists: list[list[int]] = []
        self.last_log_time = time.monotonic()

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        self.num_drafts.append(spec_decoding_stats.num_drafts)
        self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
        self.num_accepted_tokens.append(spec_decoding_stats.num_accepted_tokens)
        self.accepted_tokens_per_pos_lists.append(
            spec_decoding_stats.num_accepted_tokens_per_pos
        )

    def log(self, log_fn=logger.info):
        if not self.num_drafts:
            return
        num_drafts = np.sum(self.num_drafts)
        num_draft_tokens = np.sum(self.num_draft_tokens)
        num_accepted_tokens = np.sum(self.num_accepted_tokens)
        draft_throughput = 0
        accepted_throughput = 0

        elapsed_time = time.monotonic() - self.last_log_time
        if elapsed_time > 0:
            draft_throughput = num_draft_tokens / elapsed_time
            accepted_throughput = num_accepted_tokens / elapsed_time

        draft_acceptance_rate = (
            num_accepted_tokens / num_draft_tokens * 100
            if num_draft_tokens > 0
            else float("nan")
        )

        # Conventionally, mean acceptance length includes the bonus token
        mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts)

        pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
        acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts
        rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates)

        log_fn(
            "SpecDecoding metrics: "
            "Mean acceptance length: %.2f, "
            "Accepted throughput: %.2f tokens/s, "
            "Drafted throughput: %.2f tokens/s, "
            "Accepted: %d tokens, "
            "Drafted: %d tokens, "
            "Per-position acceptance rate: %s, "
            "Avg Draft acceptance rate: %.1f%%",
            mean_acceptance_length,
            accepted_throughput,
            draft_throughput,
            num_accepted_tokens,
            num_draft_tokens,
            rates_str,
            draft_acceptance_rate,
        )
        self.reset()


class SpecDecodingProm:
    """Record spec decoding metrics in Prometheus.

    The acceptance rate can be calculated using a PromQL query:

        rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
        rate(vllm:spec_decode_num_draft_tokens_total[$interval])

    The mean acceptance length (conventionally including bonus tokens)
    can be calculated using:

        1 + (
            rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
            rate(vllm:spec_decode_num_drafts[$interval]))

    A per-position acceptance rate vector can be computed using

        vllm:spec_decode_num_accepted_tokens_per_pos[$interval] /
        vllm:spec_decode_num_drafts[$interval]
    """

    _counter_cls = prometheus_client.Counter

    def __init__(
        self,
        speculative_config: SpeculativeConfig | None,
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        self.spec_decoding_enabled = speculative_config is not None
        if not self.spec_decoding_enabled:
            return

        counter_drafts = self._counter_cls(
            name="vllm:spec_decode_num_drafts",
            documentation="Number of spec decoding drafts.",
            labelnames=labelnames,
        )
        self.counter_spec_decode_num_drafts = make_per_engine(
            counter_drafts, per_engine_labelvalues
        )

        counter_draft_tokens = self._counter_cls(
            name="vllm:spec_decode_num_draft_tokens",
            documentation="Number of draft tokens.",
            labelnames=labelnames,
        )
        self.counter_spec_decode_num_draft_tokens = make_per_engine(
            counter_draft_tokens, per_engine_labelvalues
        )

        counter_accepted_tokens = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens",
            documentation="Number of accepted tokens.",
            labelnames=labelnames,
        )
        self.counter_spec_decode_num_accepted_tokens = make_per_engine(
            counter_accepted_tokens, per_engine_labelvalues
        )

        assert speculative_config is not None
        num_spec_tokens = (
            speculative_config.num_speculative_tokens
            if self.spec_decoding_enabled
            else 0
        )
        pos_labelnames = labelnames + ["position"]
        base_counter = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens_per_pos",
            documentation="Accepted tokens per draft position.",
            labelnames=pos_labelnames,
        )
        self.counter_spec_decode_num_accepted_tokens_per_pos: dict[
            int, list[prometheus_client.Counter]
        ] = {
            idx: [base_counter.labels(*lv, str(pos)) for pos in range(num_spec_tokens)]
            for idx, lv in per_engine_labelvalues.items()
        }

    def observe(self, spec_decoding_stats: SpecDecodingStats, engine_idx: int = 0):
        if not self.spec_decoding_enabled:
            return
        self.counter_spec_decode_num_drafts[engine_idx].inc(
            spec_decoding_stats.num_drafts
        )
        self.counter_spec_decode_num_draft_tokens[engine_idx].inc(
            spec_decoding_stats.num_draft_tokens
        )
        self.counter_spec_decode_num_accepted_tokens[engine_idx].inc(
            spec_decoding_stats.num_accepted_tokens
        )
        for pos, counter in enumerate(
            self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]
        ):
            counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])


def make_per_engine(
    counter: prometheus_client.Counter,
    per_engine_labelvalues: dict[int, list[object]],
):
    """Create a counter for each label value."""
    return {
        idx: counter.labels(*labelvalues)
        for idx, labelvalues in per_engine_labelvalues.items()
    }
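A worked example (not part of this commit) of the quantities that log() reports, written out in plain Python for one logging interval.

# Suppose 4 drafts were made with num_spec_tokens=3 and the per-draft
# accepted counts were [3, 1, 0, 2]:
num_drafts = 4
num_draft_tokens = 4 * 3                # 12 drafted tokens
num_accepted_tokens = 3 + 1 + 0 + 2     # 6 accepted tokens
draft_acceptance_rate = num_accepted_tokens / num_draft_tokens * 100  # 50.0 %
mean_acceptance_length = 1 + num_accepted_tokens / num_drafts          # 2.5
# Per-position accepted counts are [3, 2, 1], so the per-position
# acceptance rates are [0.75, 0.50, 0.25].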
vllm/v1/spec_decode/ngram_proposer.py (new file, 285 lines)
@@ -0,0 +1,285 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import numpy as np
import torch
from numba import get_num_threads, jit, njit, prange, set_num_threads

from vllm.config import VllmConfig


class NgramProposer:
    def __init__(self, vllm_config: VllmConfig):
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None

        # Minimum length of the n-gram to match.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        # Maximum length of the n-gram to match.
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        # Number of tokens to propose following the match. If fewer than k
        # tokens follow the match, we return as many tokens as are available
        # up to the end.
        self.k = vllm_config.speculative_config.num_speculative_tokens
        # Maximum length of the model.
        self.max_model_len = vllm_config.model_config.max_model_len

        # Pre-allocate buffers for numba batch propose.
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32)
        self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)

        # Threshold of total number of tokens in the batch to enable
        # multi-threading in numba batch propose.
        self.num_tokens_threshold = 8192
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        cpu_count = os.cpu_count()
        # Max number of threads for numba parallel processing.
        if cpu_count:
            # Divide by 2 to use physical cores
            # and not logical cores (hyper-threading).
            # Cap the number of threads to 8 to avoid using too many threads
            # since other components like the frontend (incl. tokenization)
            # and Structured Outputs also use multiple threads.
            # TODO(ekagra-ranjan): bump up the cap from 1 to 8
            # when TP parallelization for ngram is implemented.
            self.num_numba_thread_available = min(1, (cpu_count // 2))
            # Divide by tp_size to ensure each tensor parallel rank
            # has some threads since all ranks will run this.
            self.num_numba_thread_available //= tp_size
        else:
            self.num_numba_thread_available = 1

        # Trigger Numba JIT compilation for the N-gram proposer.
        # This usually takes less than 1 second.
        self.propose(
            [[]] * 1024,
            np.zeros(1024, dtype=np.int32),
            np.zeros((1024, self.max_model_len), dtype=np.int32),
        )

    def batch_propose(
        self,
        num_requests: int,
        valid_ngram_requests: list,
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
    ) -> list[list[int]]:
        """Batch version of the ngram proposer using numba for acceleration.

        Args:
            valid_ngram_requests:
                Set of indices of requests that need ngram proposals.
            num_tokens_no_spec:
                Numpy array of shape (batch_size,) representing the number
                of tokens without speculative tokens for each request.
            token_ids_cpu:
                Numpy array of shape (batch_size, max_model_len)
                representing the token IDs for each request.

        Returns:
            list[list[int]]:
                A list where each element is a list of proposed
                token IDs for the corresponding request.
        """
        draft_token_ids: list[list[int]] = []

        # Only run batch propose if there are requests needing ngram proposals.
        # Avoid calling the numba function with an empty list, which causes:
        # ValueError: cannot compute fingerprint of empty list
        if num_ngram_requests := len(valid_ngram_requests):
            original_num_numba_threads = get_num_threads()
            # Ensure we use at least one thread.
            # If the total token count is small, using multiple threads
            # may slow things down due to overhead.
            total_tokens = np.sum(num_tokens_no_spec)
            if total_tokens >= self.num_tokens_threshold:
                final_num_threads = max(
                    1, min(self.num_numba_thread_available, num_ngram_requests)
                )
                set_num_threads(final_num_threads)
            else:
                set_num_threads(1)

            batch_propose_numba(
                valid_ngram_requests,
                num_tokens_no_spec,
                token_ids_cpu,
                self.min_n,
                self.max_n,
                self.max_model_len,
                self.k,
                self.valid_ngram_draft,
                self.valid_ngram_num_drafts,
            )

            # Restore the original number of threads.
            set_num_threads(original_num_numba_threads)

        for i in range(num_requests):
            if i in valid_ngram_requests and self.valid_ngram_num_drafts[i] > 0:
                draft_token_ids.append(
                    self.valid_ngram_draft[i, : self.valid_ngram_num_drafts[i]].tolist()
                )
            else:
                draft_token_ids.append([])

        return draft_token_ids

    def propose(
        self,
        sampled_token_ids: list[list[int]],
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,  # unused
    ) -> list[list[int]]:
        # Find which requests need ngram proposals.
        valid_ngram_requests = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding.
                continue

            num_tokens = num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                continue

            valid_ngram_requests.append(i)

        draft_token_ids = self.batch_propose(
            len(sampled_token_ids),
            valid_ngram_requests,
            num_tokens_no_spec,
            token_ids_cpu,
        )

        return draft_token_ids

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass


@njit(parallel=True)
def batch_propose_numba(
    valid_ngram_requests: list,
    num_tokens_no_spec: np.ndarray,
    token_ids_cpu: np.ndarray,
    min_n: int,
    max_n: int,
    max_model_len: int,
    k: int,
    valid_ngram_draft: np.ndarray,
    valid_ngram_num_drafts: np.ndarray,
):
    for i in prange(len(valid_ngram_requests)):
        idx = valid_ngram_requests[i]
        num_tokens = num_tokens_no_spec[idx]
        context_token_ids = token_ids_cpu[idx, :num_tokens]
        drafter_output = _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=context_token_ids,
            min_ngram=min_n,
            max_ngram=max_n,
            max_model_len=max_model_len,
            k=k,
        )

        valid_ngram_num_drafts[idx] = drafter_output.shape[0]
        if len(drafter_output):
            valid_ngram_draft[idx, : drafter_output.shape[0]] = drafter_output


@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(
    origin_tokens: np.ndarray,
    min_ngram: int,
    max_ngram: int,
    max_model_len: int,
    k: int,
) -> np.ndarray:
    """
    Find the longest n-gram which matches the suffix of the given tokens
    and whose length is within [min_ngram, max_ngram] (inclusive).

    If found, extract the k tokens right after the matched ngram.
    """
    # Do not generate draft tokens if the context is shorter than the
    # minimum n-gram length.
    total_token = origin_tokens.shape[0]
    if total_token < min_ngram:
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Do not generate draft tokens beyond the max model length.
    k = min(k, max_model_len - total_token)
    if k <= 0:
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Flip the tokens; the goal becomes finding the longest ngram
    # at the rightmost position which matches the prefix with
    # length in [min_n, max_n] (inclusive).
    tokens = origin_tokens[::-1]

    # Longest prefix (not including itself) which is a suffix of
    # the current position.
    # lps[i] = max{v, where tokens[0:v] == tokens[i+1-v:i+1]}
    #
    # As the ngram length is capped by max_ngram to save memory, we only
    # need to store lps for the first max_ngram prefix.
    lps = np.zeros(max_ngram, dtype=np.int32)

    longest_ngram = 0
    position = 0

    # lps[0] is always 0, so we start at index 1.
    prev_lps = 0
    i = 1
    while i < total_token:
        # tokens[:prev_lps] is the longest prefix that is a suffix of tokens[:i]
        if tokens[prev_lps] == tokens[i]:
            # Token match: tokens[:prev_lps+1] is the longest prefix that is
            # a suffix of tokens[:i+1]
            prev_lps += 1
            # Check if we found a longer valid ngram.
            #
            # Update position when longest_ngram matches prev_lps,
            # as we want to get the target n-gram at the earliest position
            # in the original tokens (i.e. the latest position in the
            # reversed tokens).
            if prev_lps >= longest_ngram:
                longest_ngram = prev_lps
                position = i
            if i < max_ngram:
                # Store LPS for the first max_ngram prefix
                lps[i] = prev_lps
            if prev_lps == max_ngram:
                # When prev_lps reaches max_ngram, update prev_lps
                # to lps[max_ngram-1] to avoid matching an ngram
                # longer than max_ngram
                prev_lps = lps[max_ngram - 1]
            i += 1
        elif prev_lps != 0:
            # Token mismatch: try the second-longest prefix
            # among all suffixes of tokens[:i],
            # which is the longest prefix of tokens[:prev_lps]
            prev_lps = lps[prev_lps - 1]
        else:
            # Token mismatch, and no more prefix (except the empty string)
            # is a suffix of tokens[:i]
            i += 1

    if longest_ngram < min_ngram:
        # No valid ngram was found
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Flip the position back, so in origin_tokens,
    # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]
    # is the matched ngram, and we should start drafting tokens from
    # total_token-1-position+longest_ngram
    start_position = total_token - 1 - position + longest_ngram
    k = min(k, total_token - start_position)
    return origin_tokens[start_position : start_position + k]
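A toy illustration (not part of this commit) of what the n-gram matcher produces: the suffix of the context also appears earlier, so the proposer drafts the tokens that followed that earlier occurrence.

import numpy as np

context = np.array([1, 2, 3, 9, 1, 2, 3], dtype=np.int32)
# With min_ngram=2, max_ngram=3, k=3, and a large max_model_len, the suffix
# [1, 2, 3] matches context[0:3], so drafting starts at index 3 and the
# proposed draft tokens are [9, 1, 2].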
vllm/v1/spec_decode/suffix_decoding.py (new file, 101 lines)
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch


class SuffixDecodingProposer:
    """
    Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975).
    This class imports and uses the official implementation from Arctic Inference
    (https://github.com/snowflakedb/ArcticInference).
    """

    def __init__(self, vllm_config: VllmConfig):
        config = vllm_config.speculative_config
        assert config is not None, "Speculative config must be set"
        self.num_speculative_tokens = config.num_speculative_tokens
        self.max_tree_depth = config.suffix_decoding_max_tree_depth
        self.max_spec_factor = config.suffix_decoding_max_spec_factor
        self.min_token_prob = config.suffix_decoding_min_token_prob
        self.max_model_len = vllm_config.model_config.max_model_len

        # Lazy import to avoid an error when Suffix Decoding is not used.
        from arctic_inference.suffix_decoding import SuffixDecodingCache

        # Initialize an empty cache. This object takes care of caching request
        # outputs, evicting old requests, and managing the per-prompt suffix trees.
        self.suffix_cache = SuffixDecodingCache(
            max_tree_depth=config.suffix_decoding_max_tree_depth,
            max_cached_requests=config.suffix_decoding_max_cached_requests,
        )

    def propose(
        self,
        input_batch: InputBatch,
        sampled_token_ids: list[list[int]],
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,  # unused
    ) -> list[list[int]]:
        """
        Propose speculative tokens for each request in the input batch. Suffix Decoding
        speculates a dynamic number of tokens for each request every decoding step,
        so each entry in the returned list may have a different length.
        """
        draft_token_ids: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            if not sampled_ids:
                # Skip speculative decoding for partial prefills.
                draft_token_ids.append([])
                continue

            req_id = input_batch.req_ids[i]
            num_tokens = input_batch.num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                draft_token_ids.append([])
                continue

            index = input_batch.req_id_to_index[req_id]
            if req_id not in self.suffix_cache.active_requests:
                if req_id in self.suffix_cache.cached_requests:
                    # Reset the suffix cache for this request.
                    self.suffix_cache.evict_cached_response(req_id)
                num_prompt_tokens = input_batch.num_prompt_tokens[index]
                prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
                # Start a new request; this will build the suffix tree for that prompt.
                self.suffix_cache.start_request(req_id, prompt_token_ids)

            # Append the newly sampled ids to the suffix cache for this request.
            self.suffix_cache.add_active_response(req_id, sampled_ids)

            # Suffix decoding only uses the most recent tokens up to max_tree_depth, so
            # we extract the pattern from the end of the input.
            start = max(0, num_tokens - self.max_tree_depth)
            pattern = input_batch.token_ids_cpu[i, start:num_tokens]
            draft = self.suffix_cache.speculate(
                req_id,
                pattern,
                max_spec_tokens=min(
                    self.num_speculative_tokens, self.max_model_len - num_tokens - 1
                ),
                max_spec_factor=self.max_spec_factor,
                min_token_prob=self.min_token_prob,
            )

            draft_token_ids.append(draft.token_ids)

        # Stop requests that were not seen in the input batch.
        for req_id in (
            self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()
        ):
            self.suffix_cache.stop_request(req_id)

        return draft_token_ids

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass
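A hypothetical standalone sketch (not part of this commit) of the SuffixDecodingCache calls used above, with made-up request ids, token ids, and parameter values; the exact arctic_inference API may differ from what is shown here.

from arctic_inference.suffix_decoding import SuffixDecodingCache

cache = SuffixDecodingCache(max_tree_depth=64, max_cached_requests=1000)
cache.start_request("req-0", [1, 2, 3, 4])    # build the suffix tree from the prompt
cache.add_active_response("req-0", [5, 6])    # append newly sampled tokens
draft = cache.speculate(
    "req-0",
    [3, 4, 5, 6],                 # most recent tokens, up to max_tree_depth
    max_spec_tokens=4,
    max_spec_factor=1.0,
    min_token_prob=0.1,
)
print(draft.token_ids)            # dynamic-length list of proposed tokens
cache.stop_request("req-0")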
vllm/v1/spec_decode/utils.py (new file, 357 lines)
@@ -0,0 +1,357 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.config import VllmConfig, replace
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata,
)

PADDING_SLOT_ID = -1


@triton.jit
def eagle_prepare_inputs_padded_kernel(
    cu_num_draft_tokens_ptr,  # [num_reqs]
    valid_sampled_tokens_count_ptr,  # [num_reqs]
    query_start_loc_gpu_ptr,  # [num_reqs + 1]
    token_indices_to_sample_ptr,  # [num_reqs] (output)
    num_rejected_tokens_gpu_ptr,  # [num_reqs] (output)
    num_reqs,  # tl.int32
):
    """
    Fused kernel for Eagle prepare_input_padded. This kernel computes the
    token index to sample for each request, taking into account the number
    of draft tokens and the number of valid sampled tokens (which is one more than
    the number of accepted tokens).
    """
    req_idx = tl.program_id(axis=0)
    if req_idx >= num_reqs:
        return

    # Calculate num_draft_tokens from cu_num_draft_tokens, which is an inclusive
    # cumulative sum (first entry is the first value, not zero).
    cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + req_idx)

    num_draft_tokens = 0
    if req_idx == 0:
        num_draft_tokens = cu_draft_curr
    else:
        cu_draft_prev = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
        num_draft_tokens = cu_draft_curr - cu_draft_prev

    valid_count = tl.load(valid_sampled_tokens_count_ptr + req_idx)
    num_rejected_tokens = num_draft_tokens + 1 - valid_count
    num_rejected_tokens = tl.where(num_draft_tokens > 0, num_rejected_tokens, 0)

    # query_start_loc[req_idx + 1] is the start position of the next request,
    # which is one past the last token of this request.
    q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + req_idx + 1) - 1

    index_to_sample = q_last_tok_idx - num_rejected_tokens
    tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
    tl.store(num_rejected_tokens_gpu_ptr + req_idx, num_rejected_tokens)
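A worked example (not part of this commit) of the per-request arithmetic in the kernel above, written out in plain Python for a batch of two requests.

# Request 0: 3 draft tokens, 2 valid sampled tokens (1 accepted + 1 bonus).
# Request 1: 3 draft tokens, 4 valid sampled tokens (3 accepted + 1 bonus).
cu_num_draft_tokens = [3, 6]           # inclusive cumulative sum
valid_sampled_tokens_count = [2, 4]
query_start_loc = [0, 4, 8]            # each request contributed 4 query tokens

# Request 0: num_draft = 3, num_rejected = 3 + 1 - 2 = 2,
#            last token index = 4 - 1 = 3, index_to_sample = 3 - 2 = 1
# Request 1: num_draft = 3, num_rejected = 3 + 1 - 4 = 0,
#            last token index = 8 - 1 = 7, index_to_sample = 7 - 0 = 7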
@triton.jit
def eagle_prepare_next_token_padded_kernel(
    sampled_token_ids_ptr,  # [num_reqs, num_sampled_tokens_per_req]
    discard_request_mask_ptr,  # [num_reqs]
    backup_next_token_ids_ptr,  # [num_reqs]
    next_token_ids_ptr,  # [num_reqs] (output)
    valid_sampled_tokens_count_ptr,  # [num_reqs] (output)
    vocab_size,  # tl.int32
    num_sampled_tokens_per_req,  # tl.int32 (num_spec_tokens + 1)
    num_reqs,  # tl.int32
    stride_sampled_token_ids,  # tl.int32 (stride for dim 0)
    BLOCK_SIZE_TOKENS: tl.constexpr,  # Power-of-2 >= num_sampled_tokens_per_req
):
    """
    Fused kernel for Eagle prepare_next_token_ids_padded. This kernel computes the
    number of valid (1 + accepted) tokens for each request, and the corresponding
    "next" token id to sample from during speculative decoding. This is the
    "last accepted token" from the sampled tokens, or the backup token if no
    tokens were accepted or if the request is marked as discarded.
    """
    req_idx = tl.program_id(axis=0)
    if req_idx >= num_reqs:
        return

    # Check if this request is discarded.
    is_discarded = tl.load(discard_request_mask_ptr + req_idx)

    if is_discarded:
        backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
        valid_count = tl.full((), 0, dtype=tl.uint32)
        tl.store(next_token_ids_ptr + req_idx, backup_token)
        tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
    else:
        # Count the number of valid tokens among the sampled tokens.
        token_offs = tl.arange(0, BLOCK_SIZE_TOKENS)
        token_mask = token_offs < num_sampled_tokens_per_req

        row_ptr = sampled_token_ids_ptr + req_idx * stride_sampled_token_ids
        token_ids = tl.load(row_ptr + token_offs, mask=token_mask, other=-1)

        # Rejected tokens are -1, valid tokens are in [0, vocab_size)
        is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask
        valid_count = tl.sum(is_valid_mask)

        if valid_count > 0:
            # Guaranteed to be well-defined since
            # valid_count > 0 implies is_valid_mask is not empty
            last_valid_index = tl.max(tl.where(is_valid_mask, token_offs, -1))

            # Select the token at that index, using a sum trick since
            # we don't want to load again to access token_ids[last_valid_index].
            last_valid_token = tl.sum(
                tl.where(token_offs == last_valid_index, token_ids, 0)
            )
            tl.store(next_token_ids_ptr + req_idx, last_valid_token)
        else:
            # No valid tokens found, use backup token
            backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
            tl.store(next_token_ids_ptr + req_idx, backup_token)

        tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)


def compute_new_slot_mapping(
    cad: CommonAttentionMetadata,
    new_positions: torch.Tensor,
    is_rejected_token_mask: torch.Tensor,
    block_size: int,
    num_new_tokens: int,
    max_model_len: int,
):
    batch_size, n_blocks_per_req = cad.block_table_tensor.shape
    req_indices = torch.arange(batch_size, device=cad.query_start_loc.device)
    req_indices = torch.repeat_interleave(
        req_indices,
        cad.naive_query_lens() + num_new_tokens,
        output_size=len(new_positions),
    )
    # Clamp the positions to prevent an out-of-bounds error when indexing
    # into block_table_tensor.
    clamped_positions = torch.clamp(new_positions, max=max_model_len - 1)
    block_table_indices = (
        req_indices * n_blocks_per_req + clamped_positions // block_size
    )
    block_nums = cad.block_table_tensor.view(-1)[block_table_indices]
    block_offsets = clamped_positions % block_size
    new_slot_mapping = block_nums * block_size + block_offsets
    # Mask out the position ids that exceed the max model length.
    exceeds_max_model_len = new_positions >= max_model_len
    new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
    # Mask out rejected tokens to prevent saves to the KV cache.
    new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)
    return new_slot_mapping
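A worked example (not part of this commit) of the slot-mapping arithmetic above for a single request, with an assumed block size and block table.

import torch

block_size = 16
block_table = torch.tensor([7, 3, 9])   # assumed physical blocks for one request
position = torch.tensor([20])           # token position within the sequence
# Logical block 20 // 16 == 1 maps to physical block 3, offset 20 % 16 == 4,
# so the KV-cache slot is 3 * 16 + 4 == 52. Positions past max_model_len and
# rejected tokens are later overwritten with PADDING_SLOT_ID (-1).
slot = block_table[position // block_size] * block_size + position % block_size
print(slot)  # tensor([52])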
def create_vllm_config_for_draft_model(
    target_model_vllm_config: VllmConfig,
) -> VllmConfig:
    """The vllm_config is configured for the target model, e.g.
    its quant_config and parallel_config. But the draft model is potentially
    quantized differently, and has a potentially different tensor_parallel_size.
    This function creates a new vllm_config configured for the drafter.
    The vllm_config is useful when loading the draft model with get_model().
    """
    old = target_model_vllm_config
    assert old.speculative_config is not None, "speculative_config is not set"
    old_spec_config = old.speculative_config
    new_parallel_config = replace(
        old_spec_config.draft_parallel_config, rank=old.parallel_config.rank
    )
    new: VllmConfig = replace(
        old,
        quant_config=None,
        parallel_config=new_parallel_config,
        model_config=old_spec_config.draft_model_config,
    )
    return new


def extend_all_queries_by_N(
    common_attn_metadata: CommonAttentionMetadata,
    N: int,
    arange: torch.Tensor,
    new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
    """
    Creates a new CommonAttentionMetadata with all query lengths increased by N.
    All seq lens are also increased by N.
    This is useful e.g. in speculative decoding with parallel drafting, where we
    extend each sequence by N tokens and predict all tokens in one pass.
    The slot mapping is computed externally, as it requires more information.
    """
    cad = common_attn_metadata
    # query start loc must be increased by [+0, +N, +2N, ..., +batch_size * N]
    new_query_start_loc = cad.query_start_loc + N * arange[: len(cad.query_start_loc)]
    new_query_start_loc_cpu = cad.query_start_loc_cpu + N * torch.arange(
        len(cad.query_start_loc_cpu), dtype=torch.int32
    )
    new_cad = cad.replace(
        query_start_loc=new_query_start_loc,
        query_start_loc_cpu=new_query_start_loc_cpu,
        seq_lens=cad.seq_lens + N,
        # each request is extended by N tokens -> batch_size * N tokens are added
        num_actual_tokens=cad.num_actual_tokens + cad.batch_size() * N,
        # All query lens increase by N, so max query len increases by N
        max_query_len=cad.max_query_len + N,
        max_seq_len=cad.max_seq_len + N,
        slot_mapping=new_slot_mapping,
    )
    return new_cad
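A numeric example (not part of this commit) of the query_start_loc shift performed above, for an assumed batch of three requests.

import torch

N = 2
query_start_loc = torch.tensor([0, 3, 5, 9])   # 3 requests with lengths 3, 2, 4
arange = torch.arange(len(query_start_loc))
new_query_start_loc = query_start_loc + N * arange
print(new_query_start_loc)  # tensor([ 0,  5,  9, 15]) -> lengths 5, 4, 6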
# Unified copy/expand kernel
@triton.jit
def copy_and_expand_eagle_inputs_kernel(
    # (Padded) Inputs from the target model
    target_token_ids_ptr,  # [total_tokens_in_batch]
    target_positions_ptr,  # [total_tokens_in_batch]
    next_token_ids_ptr,  # [num_reqs]
    # Outputs to the drafting buffers
    out_input_ids_ptr,  # [total_draft_tokens_in_batch] (output)
    out_positions_ptr,  # [total_draft_tokens_in_batch] (output)
    out_is_rejected_token_mask_ptr,  # [total_draft_tokens_in_batch] (output)
    out_is_masked_token_mask_ptr,  # [total_draft_tokens_in_batch] (output)
    out_new_token_indices_ptr,  # [num_padding_slots_per_request * num_reqs] (output)
    out_hidden_state_mapping_ptr,  # [total_tokens_in_batch]
    # Input metadata
    query_start_loc_ptr,  # [num_reqs + 1], last value is the total num input tokens
    query_end_loc_ptr,  # [num_reqs]
    padding_token_id,  # tl.int32
    parallel_drafting_token_id,  # tl.int32
    # Sizing info
    total_input_tokens,  # tl.int32
    num_padding_slots_per_request,  # tl.int32
    shift_input_ids,  # tl.bool
    BLOCK_SIZE_TOKENS: tl.constexpr,  # Blocks along token dim to handle prefills
):
    """
    Copy and expand inputs from the target model to the drafting buffers for Eagle
    speculative decoding. This kernel handles padding slots and parallel drafting
    tokens, if enabled.
    """
    request_idx = tl.program_id(axis=0)
    token_batch_idx = tl.program_id(axis=1)

    # Load query locations
    query_start_loc = tl.load(query_start_loc_ptr + request_idx)
    next_query_start_loc = tl.load(query_start_loc_ptr + request_idx + 1)
    query_end_loc = tl.load(query_end_loc_ptr + request_idx)

    # Calculate number of valid tokens to copy and input offset.
    # With shift_input_ids=True, we skip the first token.
    # Output layout: each request gets (input_len + num_padding_slots_per_request)
    # slots, but with the shift we lose one token per request.
    if shift_input_ids:
        num_valid_tokens = query_end_loc - query_start_loc
        input_offset = 1
        output_start = query_start_loc + request_idx * (
            num_padding_slots_per_request - 1
        )
    else:
        num_valid_tokens = query_end_loc - query_start_loc + 1
        input_offset = 0
        output_start = query_start_loc + request_idx * num_padding_slots_per_request

    # Number of rejected tokens from previous speculation
    num_rejected = next_query_start_loc - query_end_loc - 1

    # Total output tokens for this request
    total_output_tokens = (
        num_valid_tokens + num_padding_slots_per_request + num_rejected
    )

    # Process tokens in this block
    j = token_batch_idx * BLOCK_SIZE_TOKENS + tl.arange(0, BLOCK_SIZE_TOKENS)

    # Compute masks for different output regions:
    # [0, num_valid_tokens): valid tokens copied from input
    # [num_valid_tokens]: bonus token from next_token_ids
    # (num_valid_tokens, num_valid_tokens + num_padding_slots_per_request):
    #   parallel drafting slots
    # [num_valid_tokens + num_padding_slots_per_request, total_output_tokens):
    #   rejected slots
    in_bounds = j < total_output_tokens
    is_valid_region = j < num_valid_tokens
    is_bonus_region = j == num_valid_tokens
    is_parallel_draft_region = (j > num_valid_tokens) & (
        j < num_valid_tokens + num_padding_slots_per_request
    )
    is_rejected_region = j >= num_valid_tokens + num_padding_slots_per_request

    # Compute output indices
    out_idx = output_start + j

    # For valid tokens, compute the input index
    in_idx = query_start_loc + input_offset + j
    # Clamp to avoid out-of-bounds access (masked loads still need valid addresses)
    in_idx_clamped = tl.minimum(in_idx, total_input_tokens - 1)

    # Load input tokens (masked to the valid region)
    token_ids = tl.load(
        target_token_ids_ptr + in_idx_clamped, mask=is_valid_region & in_bounds, other=0
    )

    # Load the starting position for this request (first position in the sequence)
    start_pos = tl.load(target_positions_ptr + query_start_loc)

    # Load the bonus token for this request
    bonus_token = tl.load(next_token_ids_ptr + request_idx)

    # Build final token_ids based on region
    token_ids = tl.where(is_bonus_region, bonus_token, token_ids)
    token_ids = tl.where(
        is_parallel_draft_region, parallel_drafting_token_id, token_ids
    )
    token_ids = tl.where(is_rejected_region, padding_token_id, token_ids)

    # Build final positions:
    # Positions are NOT shifted - they start from the first input position and
    # increment. Output position j gets start_pos + j
    # (e.g., input positions [5,6,7] -> output [5,6,7,8,9,...])
    positions = start_pos + j
    # Rejected positions are don't-care, set to 0
    positions = tl.where(is_rejected_region, 0, positions)

    # Compute output masks
    is_rejected_out = is_rejected_region & in_bounds
    is_masked_out = is_parallel_draft_region & in_bounds

    # Compute indices of new tokens (bonus + parallel drafting) for sampling.
    # New tokens are at positions
    # [num_valid_tokens, num_valid_tokens + num_padding_slots_per_request)
    is_new_token_region = (j >= num_valid_tokens) & (
        j < num_valid_tokens + num_padding_slots_per_request
    )
    new_token_local_idx = (
        j - num_valid_tokens
    )  # 0 for bonus, 1, 2, ... for parallel drafting
    new_token_out_idx = (
        request_idx * num_padding_slots_per_request + new_token_local_idx
    )

    # Compute hidden state mapping (source index -> destination index).
    # This maps each input position to its corresponding output position.
    # Hidden states don't get shifted, so we map all input tokens (including rejected)
    if shift_input_ids:
        num_input_tokens_this_request = next_query_start_loc - query_start_loc
        is_input_region = j < num_input_tokens_this_request
        src_idx = query_start_loc + j
        tl.store(out_hidden_state_mapping_ptr + src_idx, out_idx, mask=is_input_region)

    # Store outputs
    tl.store(out_input_ids_ptr + out_idx, token_ids, mask=in_bounds)
    tl.store(out_positions_ptr + out_idx, positions, mask=in_bounds)
    tl.store(out_is_rejected_token_mask_ptr + out_idx, is_rejected_out, mask=in_bounds)
    tl.store(out_is_masked_token_mask_ptr + out_idx, is_masked_out, mask=in_bounds)
    tl.store(
        out_new_token_indices_ptr + new_token_out_idx,
        out_idx,
        mask=is_new_token_region & in_bounds,
    )
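A worked example (not part of this commit) of the per-request output layout built by the kernel above, without shift_input_ids and with an assumed num_padding_slots_per_request of 3 (one bonus slot plus two parallel-drafting slots).

# A request with 3 valid input tokens [11, 12, 13], bonus token 14, and one
# rejected token left over from the previous speculation expands to
# 3 + 3 + 1 = 7 output slots:
#
#   j = 0..2 -> 11, 12, 13                  (copied from the target model inputs)
#   j = 3    -> 14                          (bonus token from next_token_ids)
#   j = 4..5 -> parallel_drafting_token_id  (masked drafting slots)
#   j = 6    -> padding_token_id            (rejected slot)
#
# out_new_token_indices records the output indices of j = 3..5, so the drafter
# knows which positions to sample new tokens from, and the rejected slot is
# excluded from KV-cache writes via the rejected-token mask.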