[Feature] Integrate Suffix Spec Decoding (#4045)

### What this PR does / why we need it?
This PR integrate suffix decoding (https://arxiv.org/abs/2411.04975)
from vllm (https://github.com/vllm-project/vllm/pull/25784)

#
Suffix Decoding is a dynamic n-gram matching method that:

1. Uses suffix trees to generate speculative tokens quickly using branch
frequency counts.
2. Can keep a history of prior model responses, which tends to work very
well with repetitive agentic use cases.
3. Can be dynamically updated with newly generated tokens, and FIFO
eviction of older requests.
#
### Does this PR introduce _any_ user-facing change?
This feature should be implemented as opt-in and remain seamless for
users who do not require suffix speculative decoding.

For users who wish to enable it, they must first install
arctic-inference:
`pip install arctic-inference
`

After installation, the suffix speculative decoding feature can be
enabled using the following speculative config:
`--speculative_config '{"method": "suffix", "num_speculative_tokens":
5}'
`

### How was this patch tested?
This PR is currently being tested on vLLM
main:83f478bb19
 with PR https://github.com/vllm-project/vllm/pull/25784

In our previous testing, suffix decoding achieved a 13%-30% throughput
improvement over n-gram on the sonnet dataset, tested on vllm-ascend
v0.9.1 with concurrency ranging from 2 to 40.

- vLLM version: v0.11.2

---------

Signed-off-by: fluctlux <38945811+fluctlux@users.noreply.github.com>
This commit is contained in:
fluctlux
2025-12-01 18:41:42 +08:00
committed by GitHub
parent 3023e15e23
commit f1f6370ed9
7 changed files with 146 additions and 3 deletions

View File

@@ -19,4 +19,5 @@ librosa
soundfile
pytest_mock
msserviceprofiler>=1.2.2
mindstudio-probe>=8.3.0
mindstudio-probe>=8.3.0
arctic-inference==0.1.1

View File

@@ -146,3 +146,88 @@ def test_eagle_correctness(
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
def test_suffix_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=False)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
with VllmRunner(model_name,
speculative_config={
"method": "suffix",
"num_speculative_tokens": 8,
},
max_model_len=1024,
enforce_eager=False) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
def test_suffix_acceptance(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Check that suffix decoding caching takes effect and improves acceptance
lengths and acceptance rates over multiple runs of the same prompts.
'''
num_draft = []
num_accept = []
with VllmRunner(model_name,
speculative_config={
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
"suffix_decoding_max_cached_requests": 1000,
"num_speculative_tokens": 10,
},
max_model_len=1024,
disable_log_stats=False,
enforce_eager=False) as runner:
for i in range(10):
runner.model.chat(test_prompts[i], sampling_config)
metrics = runner.model.get_metrics()
for metric in metrics:
print(metric)
if metric.name == "vllm:spec_decode_num_draft_tokens":
num_draft.append(metric.value)
if metric.name == "vllm:spec_decode_num_accepted_tokens":
num_accept.append(metric.value)
# Calculate the acceptance rates for the first and last runs.
first_accept_tokens = num_accept[0]
first_draft_tokens = num_draft[0]
first_accept_rate = first_accept_tokens / first_draft_tokens
# Take the diff since the stats are cumulative.
last_accept_tokens = num_accept[-1] - num_accept[-2]
last_draft_tokens = num_draft[-1] - num_draft[-2]
last_accept_rate = last_accept_tokens / last_draft_tokens
# Expect the acceptance length to improve.
assert first_accept_tokens < last_accept_tokens
# Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate
# Heuristic: expect at least 80% acceptance rate at the end.
assert last_accept_rate > 0.60

View File

@@ -28,6 +28,8 @@ def __post_init__(self):
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
else:
raise ValueError("num_speculative_tokens was provided but without "
"speculative model.")
@@ -70,6 +72,10 @@ def __post_init__(self):
# draft related config as None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
self._validate_suffix_decoding()
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0

View File

@@ -19,6 +19,7 @@
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
@@ -35,6 +36,8 @@ def get_spec_decode_method(method,
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
return MtpProposer(vllm_config, device, runner)
elif method == 'suffix':
return SuffixDecodingProposer(vllm_config, device, runner)
else:
raise ValueError("Unknown speculative decoding method: "
f"{method}")

View File

@@ -14,6 +14,7 @@ class SpecDcodeType(enum.Enum):
EAGLE = 1
EAGLE3 = 2
MTP = 4
SUFFIX = 5
class Proposer:
@@ -51,4 +52,4 @@ class Proposer:
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
"""Called by execute_model in model_runner"""
raise NotImplementedError
raise NotImplementedError

View File

@@ -0,0 +1,45 @@
import torch
from vllm.config import CUDAGraphMode
from vllm.v1.spec_decode.suffix_decoding import \
SuffixDecodingProposer as VllmSuffixDecodingProposer
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
def __init__(self, vllm_config, device, runner):
super().__init__(vllm_config)
self.name = SpecDcodeType.SUFFIX
self.device = device
self.runner = runner
def load_model(self, *args, **kwargs):
# No model to load.
pass
@torch.inference_mode()
def dummy_run(self,
num_tokens,
with_prefill=None,
skip_attn=None,
num_reqs=None,
num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None):
pass
def generate_token_ids(self,
valid_sampled_token_ids,
sampling_metadata=None,
scheduler_output=None,
spec_decode_metadata=None,
positions=None,
num_scheduled_tokens=None,
hidden_states=None,
attn_metadata=None,
aux_hidden_states=None) -> list[list[int]]:
draft_token_ids = self.propose(self.runner.input_batch,
valid_sampled_token_ids)
return draft_token_ids

View File

@@ -96,6 +96,7 @@ from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -630,7 +631,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Set up speculative decoding.
self.spec_attn_mask = None
self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
TorchairMtpProposer]] = None
TorchairMtpProposer,
SuffixDecodingProposer]] = None
self.actual_seq_lengths_q: list[int] = []
self.decode_token_per_req = 1
if self.speculative_config: