diff --git a/requirements-dev.txt b/requirements-dev.txt index d3db952d..44bfc3c5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -19,4 +19,5 @@ librosa soundfile pytest_mock msserviceprofiler>=1.2.2 -mindstudio-probe>=8.3.0 \ No newline at end of file +mindstudio-probe>=8.3.0 +arctic-inference==0.1.1 \ No newline at end of file diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py index 30b75150..aec67bc3 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py @@ -146,3 +146,88 @@ def test_eagle_correctness( # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.66 * len(ref_outputs)) + + +def test_suffix_correctness( + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, +): + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using ngram speculative decoding. + ''' + ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=False) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + with VllmRunner(model_name, + speculative_config={ + "method": "suffix", + "num_speculative_tokens": 8, + }, + max_model_len=1024, + enforce_eager=False) as runner: + spec_outputs = runner.model.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 70% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.66 * len(ref_outputs)) + + +def test_suffix_acceptance( + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, +): + ''' + Check that suffix decoding caching takes effect and improves acceptance + lengths and acceptance rates over multiple runs of the same prompts. + ''' + num_draft = [] + num_accept = [] + with VllmRunner(model_name, + speculative_config={ + "method": "suffix", + "suffix_decoding_max_spec_factor": 2.0, + "suffix_decoding_max_cached_requests": 1000, + "num_speculative_tokens": 10, + }, + max_model_len=1024, + disable_log_stats=False, + enforce_eager=False) as runner: + for i in range(10): + runner.model.chat(test_prompts[i], sampling_config) + metrics = runner.model.get_metrics() + for metric in metrics: + print(metric) + if metric.name == "vllm:spec_decode_num_draft_tokens": + num_draft.append(metric.value) + if metric.name == "vllm:spec_decode_num_accepted_tokens": + num_accept.append(metric.value) + # Calculate the acceptance rates for the first and last runs. + first_accept_tokens = num_accept[0] + first_draft_tokens = num_draft[0] + first_accept_rate = first_accept_tokens / first_draft_tokens + + # Take the diff since the stats are cumulative. + last_accept_tokens = num_accept[-1] - num_accept[-2] + last_draft_tokens = num_draft[-1] - num_draft[-2] + last_accept_rate = last_accept_tokens / last_draft_tokens + + # Expect the acceptance length to improve. + assert first_accept_tokens < last_accept_tokens + + # Expect the acceptance rate to improve. + assert first_accept_rate < last_accept_rate + + # Heuristic: expect at least 80% acceptance rate at the end. + assert last_accept_rate > 0.60 diff --git a/vllm_ascend/patch/platform/patch_config.py b/vllm_ascend/patch/platform/patch_config.py index 0e8642d1..b798fda3 100644 --- a/vllm_ascend/patch/platform/patch_config.py +++ b/vllm_ascend/patch/platform/patch_config.py @@ -28,6 +28,8 @@ def __post_init__(self): self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" + elif self.method == "suffix": + self.model = "suffix" else: raise ValueError("num_speculative_tokens was provided but without " "speculative model.") @@ -70,6 +72,10 @@ def __post_init__(self): # draft related config as None here. self.draft_model_config = self.target_model_config self.draft_parallel_config = self.target_parallel_config + elif self.method == "suffix": + self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + self._validate_suffix_decoding() else: self.prompt_lookup_max = 0 self.prompt_lookup_min = 0 diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index 6abe8777..a8d44875 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -19,6 +19,7 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.spec_decode.ngram_proposer import NgramProposer +from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer @@ -35,6 +36,8 @@ def get_spec_decode_method(method, if is_torchair_graph: return TorchairMtpProposer(vllm_config, device, runner) return MtpProposer(vllm_config, device, runner) + elif method == 'suffix': + return SuffixDecodingProposer(vllm_config, device, runner) else: raise ValueError("Unknown speculative decoding method: " f"{method}") diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py index 5fdb4945..098f171f 100644 --- a/vllm_ascend/spec_decode/interface.py +++ b/vllm_ascend/spec_decode/interface.py @@ -14,6 +14,7 @@ class SpecDcodeType(enum.Enum): EAGLE = 1 EAGLE3 = 2 MTP = 4 + SUFFIX = 5 class Proposer: @@ -51,4 +52,4 @@ class Proposer: attn_metadata=None, aux_hidden_states: torch.Tensor = None): """Called by execute_model in model_runner""" - raise NotImplementedError + raise NotImplementedError \ No newline at end of file diff --git a/vllm_ascend/spec_decode/suffix_proposer.py b/vllm_ascend/spec_decode/suffix_proposer.py new file mode 100644 index 00000000..e6070449 --- /dev/null +++ b/vllm_ascend/spec_decode/suffix_proposer.py @@ -0,0 +1,45 @@ +import torch +from vllm.config import CUDAGraphMode +from vllm.v1.spec_decode.suffix_decoding import \ + SuffixDecodingProposer as VllmSuffixDecodingProposer + +from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType + + +class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer): + + def __init__(self, vllm_config, device, runner): + super().__init__(vllm_config) + self.name = SpecDcodeType.SUFFIX + self.device = device + self.runner = runner + + def load_model(self, *args, **kwargs): + # No model to load. + pass + + @torch.inference_mode() + def dummy_run(self, + num_tokens, + with_prefill=None, + skip_attn=None, + num_reqs=None, + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None, + dummy_compute_logits=lambda hidden_states: None): + pass + + def generate_token_ids(self, + valid_sampled_token_ids, + sampling_metadata=None, + scheduler_output=None, + spec_decode_metadata=None, + positions=None, + num_scheduled_tokens=None, + hidden_states=None, + attn_metadata=None, + aux_hidden_states=None) -> list[list[int]]: + draft_token_ids = self.propose(self.runner.input_batch, + valid_sampled_token_ids) + return draft_token_ids diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c85b0ee3..f5c3bb35 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -96,6 +96,7 @@ from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -630,7 +631,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Set up speculative decoding. self.spec_attn_mask = None self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer, - TorchairMtpProposer]] = None + TorchairMtpProposer, + SuffixDecodingProposer]] = None self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1 if self.speculative_config: