From f1f6370ed966bd554f8874ddc430c09f25354c72 Mon Sep 17 00:00:00 2001 From: fluctlux <38945811+fluctlux@users.noreply.github.com> Date: Mon, 1 Dec 2025 18:41:42 +0800 Subject: [PATCH] [Feature] Integrate Suffix Spec Decoding (#4045) ### What this PR does / why we need it? This PR integrate suffix decoding (https://arxiv.org/abs/2411.04975) from vllm (https://github.com/vllm-project/vllm/pull/25784) # Suffix Decoding is a dynamic n-gram matching method that: 1. Uses suffix trees to generate speculative tokens quickly using branch frequency counts. 2. Can keep a history of prior model responses, which tends to work very well with repetitive agentic use cases. 3. Can be dynamically updated with newly generated tokens, and FIFO eviction of older requests. # ### Does this PR introduce _any_ user-facing change? This feature should be implemented as opt-in and remain seamless for users who do not require suffix speculative decoding. For users who wish to enable it, they must first install arctic-inference: `pip install arctic-inference ` After installation, the suffix speculative decoding feature can be enabled using the following speculative config: `--speculative_config '{"method": "suffix", "num_speculative_tokens": 5}' ` ### How was this patch tested? This PR is currently being tested on vLLM main:https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac with PR https://github.com/vllm-project/vllm/pull/25784 In our previous testing, suffix decoding achieved a 13%-30% throughput improvement over n-gram on the sonnet dataset, tested on vllm-ascend v0.9.1 with concurrency ranging from 2 to 40. - vLLM version: v0.11.2 --------- Signed-off-by: fluctlux <38945811+fluctlux@users.noreply.github.com> --- requirements-dev.txt | 3 +- .../spec_decode_v1/test_v1_spec_decode.py | 85 +++++++++++++++++++ vllm_ascend/patch/platform/patch_config.py | 6 ++ vllm_ascend/spec_decode/__init__.py | 3 + vllm_ascend/spec_decode/interface.py | 3 +- vllm_ascend/spec_decode/suffix_proposer.py | 45 ++++++++++ vllm_ascend/worker/model_runner_v1.py | 4 +- 7 files changed, 146 insertions(+), 3 deletions(-) create mode 100644 vllm_ascend/spec_decode/suffix_proposer.py diff --git a/requirements-dev.txt b/requirements-dev.txt index d3db952d..44bfc3c5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -19,4 +19,5 @@ librosa soundfile pytest_mock msserviceprofiler>=1.2.2 -mindstudio-probe>=8.3.0 \ No newline at end of file +mindstudio-probe>=8.3.0 +arctic-inference==0.1.1 \ No newline at end of file diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py index 30b75150..aec67bc3 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py @@ -146,3 +146,88 @@ def test_eagle_correctness( # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.66 * len(ref_outputs)) + + +def test_suffix_correctness( + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, +): + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using ngram speculative decoding. + ''' + ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=False) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + with VllmRunner(model_name, + speculative_config={ + "method": "suffix", + "num_speculative_tokens": 8, + }, + max_model_len=1024, + enforce_eager=False) as runner: + spec_outputs = runner.model.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 70% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.66 * len(ref_outputs)) + + +def test_suffix_acceptance( + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, +): + ''' + Check that suffix decoding caching takes effect and improves acceptance + lengths and acceptance rates over multiple runs of the same prompts. + ''' + num_draft = [] + num_accept = [] + with VllmRunner(model_name, + speculative_config={ + "method": "suffix", + "suffix_decoding_max_spec_factor": 2.0, + "suffix_decoding_max_cached_requests": 1000, + "num_speculative_tokens": 10, + }, + max_model_len=1024, + disable_log_stats=False, + enforce_eager=False) as runner: + for i in range(10): + runner.model.chat(test_prompts[i], sampling_config) + metrics = runner.model.get_metrics() + for metric in metrics: + print(metric) + if metric.name == "vllm:spec_decode_num_draft_tokens": + num_draft.append(metric.value) + if metric.name == "vllm:spec_decode_num_accepted_tokens": + num_accept.append(metric.value) + # Calculate the acceptance rates for the first and last runs. + first_accept_tokens = num_accept[0] + first_draft_tokens = num_draft[0] + first_accept_rate = first_accept_tokens / first_draft_tokens + + # Take the diff since the stats are cumulative. + last_accept_tokens = num_accept[-1] - num_accept[-2] + last_draft_tokens = num_draft[-1] - num_draft[-2] + last_accept_rate = last_accept_tokens / last_draft_tokens + + # Expect the acceptance length to improve. + assert first_accept_tokens < last_accept_tokens + + # Expect the acceptance rate to improve. + assert first_accept_rate < last_accept_rate + + # Heuristic: expect at least 80% acceptance rate at the end. + assert last_accept_rate > 0.60 diff --git a/vllm_ascend/patch/platform/patch_config.py b/vllm_ascend/patch/platform/patch_config.py index 0e8642d1..b798fda3 100644 --- a/vllm_ascend/patch/platform/patch_config.py +++ b/vllm_ascend/patch/platform/patch_config.py @@ -28,6 +28,8 @@ def __post_init__(self): self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" + elif self.method == "suffix": + self.model = "suffix" else: raise ValueError("num_speculative_tokens was provided but without " "speculative model.") @@ -70,6 +72,10 @@ def __post_init__(self): # draft related config as None here. self.draft_model_config = self.target_model_config self.draft_parallel_config = self.target_parallel_config + elif self.method == "suffix": + self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + self._validate_suffix_decoding() else: self.prompt_lookup_max = 0 self.prompt_lookup_min = 0 diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index 6abe8777..a8d44875 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -19,6 +19,7 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.spec_decode.ngram_proposer import NgramProposer +from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer @@ -35,6 +36,8 @@ def get_spec_decode_method(method, if is_torchair_graph: return TorchairMtpProposer(vllm_config, device, runner) return MtpProposer(vllm_config, device, runner) + elif method == 'suffix': + return SuffixDecodingProposer(vllm_config, device, runner) else: raise ValueError("Unknown speculative decoding method: " f"{method}") diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py index 5fdb4945..098f171f 100644 --- a/vllm_ascend/spec_decode/interface.py +++ b/vllm_ascend/spec_decode/interface.py @@ -14,6 +14,7 @@ class SpecDcodeType(enum.Enum): EAGLE = 1 EAGLE3 = 2 MTP = 4 + SUFFIX = 5 class Proposer: @@ -51,4 +52,4 @@ class Proposer: attn_metadata=None, aux_hidden_states: torch.Tensor = None): """Called by execute_model in model_runner""" - raise NotImplementedError + raise NotImplementedError \ No newline at end of file diff --git a/vllm_ascend/spec_decode/suffix_proposer.py b/vllm_ascend/spec_decode/suffix_proposer.py new file mode 100644 index 00000000..e6070449 --- /dev/null +++ b/vllm_ascend/spec_decode/suffix_proposer.py @@ -0,0 +1,45 @@ +import torch +from vllm.config import CUDAGraphMode +from vllm.v1.spec_decode.suffix_decoding import \ + SuffixDecodingProposer as VllmSuffixDecodingProposer + +from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType + + +class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer): + + def __init__(self, vllm_config, device, runner): + super().__init__(vllm_config) + self.name = SpecDcodeType.SUFFIX + self.device = device + self.runner = runner + + def load_model(self, *args, **kwargs): + # No model to load. + pass + + @torch.inference_mode() + def dummy_run(self, + num_tokens, + with_prefill=None, + skip_attn=None, + num_reqs=None, + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None): + pass + + def generate_token_ids(self, + valid_sampled_token_ids, + sampling_metadata=None, + scheduler_output=None, + spec_decode_metadata=None, + positions=None, + num_scheduled_tokens=None, + hidden_states=None, + attn_metadata=None, + aux_hidden_states=None) -> list[list[int]]: + draft_token_ids = self.propose(self.runner.input_batch, + valid_sampled_token_ids) + return draft_token_ids diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c85b0ee3..f5c3bb35 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -96,6 +96,7 @@ from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -630,7 +631,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Set up speculative decoding. self.spec_attn_mask = None self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer, - TorchairMtpProposer]] = None + TorchairMtpProposer, + SuffixDecodingProposer]] = None self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1 if self.speculative_config: