[Feature] Integrate Suffix Spec Decoding (#4045)

### What this PR does / why we need it?
This PR integrate suffix decoding (https://arxiv.org/abs/2411.04975)
from vllm (https://github.com/vllm-project/vllm/pull/25784)

#
Suffix Decoding is a dynamic n-gram matching method that:

1. Uses suffix trees to generate speculative tokens quickly using branch
frequency counts.
2. Can keep a history of prior model responses, which tends to work very
well with repetitive agentic use cases.
3. Can be dynamically updated with newly generated tokens, and FIFO
eviction of older requests.
#
### Does this PR introduce _any_ user-facing change?
This feature should be implemented as opt-in and remain seamless for
users who do not require suffix speculative decoding.

For users who wish to enable it, they must first install
arctic-inference:
`pip install arctic-inference
`

After installation, the suffix speculative decoding feature can be
enabled using the following speculative config:
`--speculative_config '{"method": "suffix", "num_speculative_tokens":
5}'
`

### How was this patch tested?
This PR is currently being tested on vLLM
main:83f478bb19
 with PR https://github.com/vllm-project/vllm/pull/25784

In our previous testing, suffix decoding achieved a 13%-30% throughput
improvement over n-gram on the sonnet dataset, tested on vllm-ascend
v0.9.1 with concurrency ranging from 2 to 40.

- vLLM version: v0.11.2

---------

Signed-off-by: fluctlux <38945811+fluctlux@users.noreply.github.com>
This commit is contained in:
fluctlux
2025-12-01 18:41:42 +08:00
committed by GitHub
parent 3023e15e23
commit f1f6370ed9
7 changed files with 146 additions and 3 deletions

View File

@@ -146,3 +146,88 @@ def test_eagle_correctness(
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
def test_suffix_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=False)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
with VllmRunner(model_name,
speculative_config={
"method": "suffix",
"num_speculative_tokens": 8,
},
max_model_len=1024,
enforce_eager=False) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
def test_suffix_acceptance(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Check that suffix decoding caching takes effect and improves acceptance
lengths and acceptance rates over multiple runs of the same prompts.
'''
num_draft = []
num_accept = []
with VllmRunner(model_name,
speculative_config={
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
"suffix_decoding_max_cached_requests": 1000,
"num_speculative_tokens": 10,
},
max_model_len=1024,
disable_log_stats=False,
enforce_eager=False) as runner:
for i in range(10):
runner.model.chat(test_prompts[i], sampling_config)
metrics = runner.model.get_metrics()
for metric in metrics:
print(metric)
if metric.name == "vllm:spec_decode_num_draft_tokens":
num_draft.append(metric.value)
if metric.name == "vllm:spec_decode_num_accepted_tokens":
num_accept.append(metric.value)
# Calculate the acceptance rates for the first and last runs.
first_accept_tokens = num_accept[0]
first_draft_tokens = num_draft[0]
first_accept_rate = first_accept_tokens / first_draft_tokens
# Take the diff since the stats are cumulative.
last_accept_tokens = num_accept[-1] - num_accept[-2]
last_draft_tokens = num_draft[-1] - num_draft[-2]
last_accept_rate = last_accept_tokens / last_draft_tokens
# Expect the acceptance length to improve.
assert first_accept_tokens < last_accept_tokens
# Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate
# Heuristic: expect at least 80% acceptance rate at the end.
assert last_accept_rate > 0.60