diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index ba9f026..5d09094 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -161,8 +161,9 @@ jobs: if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true' || github.event_name == 'schedule' run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then + VLLM_USE_MODELSCOPE=true pytest -sv tests/singlecard/spec_decode/e2e/test_v1_spec_decode.py pytest -sv tests/singlecard/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process - pytest -sv tests/singlecard/spec_decode --ignore=tests/singlecard/spec_decode/e2e/test_mtp_correctness.py + pytest -sv tests/singlecard/spec_decode --ignore=tests/singlecard/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/singlecard/spec_decode/e2e/test_v1_spec_decode.py fi - name: Run vllm-project/vllm test for V0 Engine diff --git a/requirements-dev.txt b/requirements-dev.txt index 4fb45d1..e113c20 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,3 +9,4 @@ ray types-jsonschema xgrammar zmq +numba diff --git a/tests/sample/__init__.py b/tests/sample/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/sample/test_rejection_sampler.py b/tests/sample/test_rejection_sampler.py new file mode 100644 index 0000000..a88776f --- /dev/null +++ b/tests/sample/test_rejection_sampler.py @@ -0,0 +1,610 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional + +import pytest +import torch +import torch.nn.functional as F +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +from vllm_ascend.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, + AscendRejectionSampler) + +DEVICE = "npu" + + +@pytest.fixture +def rejection_sampler(): + return AscendRejectionSampler() + + +def create_logits_tensor(output_token_ids: list[list[int]], + vocab_size: int = 100) 
-> torch.Tensor: + """Helper function to create logits tensor that + will produce desired token ids on argmax""" + token_ids = [tokens[:-1] for tokens in output_token_ids] + num_total_tokens = sum(len(tokens) for tokens in token_ids) + logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE) + start_loc = 0 + for tokens in token_ids: + for j, token_id in enumerate(tokens): + logits[start_loc + j, token_id] = 100.0 + start_loc += len(tokens) + return logits + + +def create_sampling_metadata( + all_greedy: bool, + temperature: Optional[torch.Tensor] = None, + top_k: Optional[torch.Tensor] = None, + top_p: Optional[torch.Tensor] = None, + generators: Optional[dict[int, Any]] = None, +) -> SamplingMetadata: + """Create a v1 sampling metadata object with all_greedy set + to the given value. Either all greedy or all random sampling + is used. + """ + generators = generators or {} + if all_greedy: + temperature = None + else: + assert temperature is not None + + return SamplingMetadata( + temperature=temperature, + all_greedy=all_greedy, + all_random=not all_greedy, + top_p=top_p, + top_k=top_k, + min_p=torch.empty(1, ), + generators=generators, + max_num_logprobs=0, + no_penalties=False, + prompt_token_ids=None, + frequency_penalties=torch.tensor([]), + presence_penalties=torch.tensor([]), + repetition_penalties=torch.tensor([]), + output_token_ids=[], + min_tokens={}, + logit_bias=[None], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + ) + + +########################### Tests for Greedy Sampling ################### +def test_perfect_match(rejection_sampler): + """Test when output tokens perfectly match speculated tokens""" + spec_tokens = [[1, 2, 3]] + output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token + + metadata = create_sampling_metadata(all_greedy=True) + logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], + device=logits.device) + spec_decode_metadata = 
SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected = torch.tensor([[1, 2, 3, 4]], + dtype=torch.int, + device=logits.device) + assert torch.equal(output, expected) + + +def test_early_mismatch(rejection_sampler): + """Test when there's an early mismatch in tokens""" + spec_tokens = [[1, 2, 3]] + output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1 + + metadata = create_sampling_metadata(all_greedy=True) + logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], + device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected = torch.tensor( + [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], + dtype=torch.int, + device=logits.device, + ) + assert torch.equal(output, expected) + + +def test_multiple_sequences(rejection_sampler): + """Test handling multiple sequences of speculated tokens""" + spec_tokens = [[1, 2], [3]] + output_tokens = [[1, 2, 5], [3, + 4]] # Two sequences with bonus tokens 5 and 4 + + metadata = create_sampling_metadata(all_greedy=True) + logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor( + [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], + dtype=torch.int, + device=logits.device) + assert 
torch.equal(output, expected) + + +def test_single_token_sequence(rejection_sampler): + """Test handling sequences with single token""" + spec_tokens = [[1]] + output_tokens = [[1, 2]] # Single token with bonus token 2 + + metadata = create_sampling_metadata(all_greedy=True) + logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], + device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) + assert torch.equal(output, expected) + + +def test_empty_sequence(rejection_sampler): + """Test handling empty sequence of speculated tokens""" + spec_tokens: list[list[int]] = [[]] + output_tokens = [[5]] # Just the bonus token + + metadata = create_sampling_metadata(all_greedy=True) + logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], + device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) + assert torch.equal(output, expected) + + +def test_multiple_mismatches(rejection_sampler): + """Test handling multiple sequences with mismatches""" + spec_tokens = [[1, 2, 3], [4, 5, 6]] + output_tokens = [[1, 2, 7, 6], [4, 8, 6, + 9]] # Mismatches in both sequences + + metadata = create_sampling_metadata(all_greedy=True) + logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor( + [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) + 
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected = torch.tensor( + [[1, 2, 7, PLACEHOLDER_TOKEN_ID], + [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], + dtype=torch.int, + device=logits.device, + ) + assert torch.equal(output, expected) + + +@pytest.mark.parametrize( + "spec_tokens,output_tokens,expected", + [ + ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus + ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch + ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]], + [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches + ]) +def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, + expected): + """Parametrized test for various matching scenarios""" + metadata = create_sampling_metadata(all_greedy=True) + logits = create_logits_tensor(output_tokens) + bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens], + device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) + expected_tensor = torch.tensor(expected, + dtype=torch.int, + device=logits.device) + assert torch.equal(output, expected_tensor) + + +########################### Tests for Random Sampling ################### +@pytest.mark.parametrize("k", [1, 3, 5]) +@pytest.mark.parametrize("vocab_size", [1000]) +@pytest.mark.parametrize("batch_size", [1, 4, 8]) +@pytest.mark.parametrize("frac_seeded", [0.0, 0.5]) +@pytest.mark.parametrize("n_rep", [20]) +def test_deterministic_when_seeded( + rejection_sampler, + k: int, + vocab_size: int, + batch_size: int, + frac_seeded: float, + n_rep: int, +): + 
num_tokens = batch_size * k + draft_probs = torch.rand(num_tokens, + vocab_size, + dtype=torch.float32, + device=DEVICE) + draft_probs = F.softmax(draft_probs, dim=-1) + target_logits = torch.rand_like(draft_probs) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64, + device=DEVICE) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device=DEVICE) + + seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded + + results = [] + for _ in range(n_rep): + seeded_seqs = { + i: torch.Generator(device=DEVICE).manual_seed(i) + for i in range(batch_size) if seeded_mask[i] + } + + temperature = torch.ones(batch_size, + dtype=torch.float32, + device=DEVICE) + sampling_metadata = create_sampling_metadata(all_greedy=False, + temperature=temperature, + generators=seeded_seqs) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids.tolist(), device=DEVICE) + rep_result = rejection_sampler( + spec_decode_metadata, + draft_probs=draft_probs, + target_logits=target_logits, + bonus_token_ids=bonus_token_ids, + sampling_metadata=sampling_metadata, + ) + + results.append(rep_result) + + for i in range(batch_size): + if seeded_mask[i]: + for j in range(1, n_rep): + assert torch.equal(results[j][i], results[0][i]) + + +def test_rejection_sampling_approximates_target_distribution(): + """Verify rejection sampling approximates target distribution, + despite sampling from a potentially distinct draft distribution. + + This is done by first creating a random target probability + distribution and a random draft probability distribution. We then + sample token ids from the rejection sampler using these draft + and target distributions. The samples are used to estimate + the output probability distribution, which we expect to approximate + the target distribution. + + A basic distance metric is used to determine similarity between + distributions. 
+ + We expect that as we increase the number of samples, + the distance between the observed distribution and the target + distribution decreases. To measure this, we compare the distance + of the observed distribution against both the target distribution + and a uniform random distribution. We expect the distance between + the observed distribution and the target distribution to improve + much more than the distance improvement between the observed + distribution and the random distribution. + """ + torch.set_default_device(DEVICE) + vocab_size = 10 + k = 2 + num_reference_probs = 100 + + # Prepare draft, target, and reference probability distributions + draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), + dim=-1) + target_logits = torch.rand(vocab_size, dtype=torch.float32) + target_probs = F.softmax(target_logits, dim=-1) + reference_probs = F.softmax( + torch.rand(num_reference_probs, vocab_size, dtype=torch.float32), + dim=-1, + ) + + sample_sizes = [10, 100, 1_000, 10_000, 100_000] + distance_wrt_reference: list[float] = [] + distance_wrt_target: list[float] = [] + + for num_samples in sample_sizes: + # Sample using rejection sampling. + rej_sample_probs = estimate_rejection_sampling_pdf( + draft_probs, target_logits, k, vocab_size, num_samples) + rej_sample_probs = rej_sample_probs.to(DEVICE) + + # Average distance from reference probs. 
+ reference_vs_rejsample_dist = torch.dist( + reference_probs, + rej_sample_probs).item() / reference_probs.shape[0] + target_vs_rejsample_dist = torch.dist(target_probs, + rej_sample_probs).item() + + distance_wrt_reference.append(reference_vs_rejsample_dist) + distance_wrt_target.append(target_vs_rejsample_dist) + + relative_change_in_distance_wrt_target = get_ratio_first_to_last( + distance_wrt_target) + relative_change_in_distance_wrt_reference = get_ratio_first_to_last( + distance_wrt_reference) + + print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} " + f"{reference_vs_rejsample_dist=:.05f}") + print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} " + f"{relative_change_in_distance_wrt_reference=:.02f}") + + relative_change_in_distance_wrt_target = get_ratio_first_to_last( + distance_wrt_target) + relative_change_in_distance_wrt_reference = get_ratio_first_to_last( + distance_wrt_reference) + + expected_improvement_multiplier = 20 + assert (relative_change_in_distance_wrt_target > + relative_change_in_distance_wrt_reference * + expected_improvement_multiplier) + + +def get_ratio_first_to_last(elements: list[float]) -> float: + return elements[0] / elements[-1] + + +def estimate_rejection_sampling_pdf( + draft_probs: torch.Tensor, + target_logits: torch.Tensor, + k: int, + vocab_size: int, + num_samples: int, +) -> torch.Tensor: + """Estimate the probability distribution of the output tokens + using rejection sampling. + + Args: + draft_probs: Draft probability distribution. + target_logits: Target logits. + num_samples: Number of samples to draw. + + Returns: + Estimated probability distribution of the output tokens. + """ + rejection_sampler = AscendRejectionSampler() + num_tokens = num_samples * k + # Repeat draft probs num_samples * k times. + draft_probs = draft_probs.reshape(1, 1, + vocab_size).repeat(num_samples, k, 1) + + # Repeat target probs num_tokens times. 
+ target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1) + + # Randomly sample draft token ids from draft probs. + draft_token_ids = torch.multinomial(draft_probs[:, 0, :], + num_samples=k, + replacement=True).reshape( + num_samples, k) + draft_probs = draft_probs.view(num_tokens, vocab_size) + + # Bonus tokens not used but required. + bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, + device=DEVICE).repeat(num_samples, 1) + + temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE) + sampling_metadata = create_sampling_metadata(all_greedy=False, + temperature=temperature) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids.tolist(), device=bonus_token_ids.device) + output_token_ids = rejection_sampler( + spec_decode_metadata, + draft_probs=draft_probs, + target_logits=target_logits, + bonus_token_ids=bonus_token_ids, + sampling_metadata=sampling_metadata, + ) + output_token_ids = output_token_ids[:, :-1].flatten() + + hist = torch.histogram(output_token_ids.to(dtype=torch.float, + device="cpu"), + bins=vocab_size, + range=(0, vocab_size), + density=True) + + return hist.hist + + +def _test_masked_logits( + rejection_sampler, + batch_size: int, + num_draft_tokens: int, + vocab_size: int, + target_logits: torch.Tensor, + unmasked_indices: torch.Tensor, + sampling_metadata: SamplingMetadata, +): + # Set up test parameters + num_tokens = batch_size * num_draft_tokens + + # Create random draft probabilities. 
+ draft_probs = torch.rand((num_tokens, vocab_size), + dtype=torch.float32, + device=DEVICE) + draft_probs = F.softmax(draft_probs, dim=-1) + + # Randomly sample draft token ids from draft probs + draft_token_ids = torch.multinomial(draft_probs, num_samples=1) + draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens) + draft_token_ids = draft_token_ids.tolist() + + # Bonus tokens not used but required + bonus_token_ids = torch.zeros((batch_size, 1), + dtype=torch.int64, + device=DEVICE) + + # Create spec decode metadata + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids, + device=DEVICE, + ) + + # Run rejection sampling + output_token_ids = rejection_sampler( + spec_decode_metadata, + draft_probs=draft_probs, + target_logits=target_logits, + bonus_token_ids=bonus_token_ids, + sampling_metadata=sampling_metadata, + ) + + # Remove bonus tokens and reshape + output_token_ids = output_token_ids[:, :-1].flatten().tolist() + + # Check that all sampled tokens are within the unmasked indices. + for i in range(num_tokens): + token_id = output_token_ids[i] + if token_id == PLACEHOLDER_TOKEN_ID: + continue + assert token_id in unmasked_indices[i] + + +@pytest.mark.parametrize("top_k", [1, 5, 99]) +def test_top_k(rejection_sampler, top_k): + """Test rejection sampling with top-k sampling""" + vocab_size = 100 + batch_size = 100 + num_draft_tokens = 3 + num_tokens = batch_size * num_draft_tokens + + # Randomly create top-k indices. + top_k_indices = [ + torch.randperm(vocab_size, device=DEVICE)[:top_k] + for _ in range(num_tokens) + ] + top_k_indices = torch.stack(top_k_indices) + + # Create logits with the uniform distribution. + target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE) + + # Increment the logits for top-k indices, a little bit more than the other + # ones. If the masking is effective, the non-topk indices will never be + # sampled despite the small difference in logits. 
+ for i in range(num_tokens): + target_logits[i, top_k_indices[i]] += 0.1 + + # Create sampling metadata + temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) + sampling_metadata = create_sampling_metadata( + all_greedy=False, + temperature=temperature, + top_k=torch.tensor([top_k] * batch_size, + device=DEVICE, + dtype=torch.int64), + ) + + _test_masked_logits( + rejection_sampler, + batch_size=batch_size, + num_draft_tokens=num_draft_tokens, + vocab_size=vocab_size, + target_logits=target_logits, + unmasked_indices=top_k_indices, + sampling_metadata=sampling_metadata, + ) + + +@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99]) +def test_top_p(rejection_sampler, top_p): + """Test rejection sampling with top-p sampling""" + vocab_size = 100 + batch_size = 100 + num_draft_tokens = 3 + num_tokens = batch_size * num_draft_tokens + + # Create logits with the uniform distribution. + target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE) + temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) + rescaled_logits = target_logits / temperature + + logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False) + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - top_p + # at least one + top_p_mask[:, -1] = False + + # Get the top-p indices. 
+ top_p_indices = [] + for i in range(num_tokens): + top_p_indices.append(logits_idx[i][~top_p_mask[i]].tolist()) + + # Create sampling metadata + sampling_metadata = create_sampling_metadata( + all_greedy=False, + temperature=temperature, + top_p=torch.tensor([top_p] * batch_size, + device=DEVICE, + dtype=torch.float32), + ) + + _test_masked_logits( + rejection_sampler, + batch_size=batch_size, + num_draft_tokens=num_draft_tokens, + vocab_size=vocab_size, + target_logits=target_logits, + unmasked_indices=top_p_indices, + sampling_metadata=sampling_metadata, + ) diff --git a/tests/singlecard/spec_decode/e2e/test_v1_spec_decode.py b/tests/singlecard/spec_decode/e2e/test_v1_spec_decode.py new file mode 100644 index 0000000..d7bac41 --- /dev/null +++ b/tests/singlecard/spec_decode/e2e/test_v1_spec_decode.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import os +import random +from typing import Any + +import pytest +from vllm import LLM, SamplingParams + +os.environ["VLLM_USE_MODELSCOPE"] = "True" + + +@pytest.fixture +def test_prompts(): + prompt_types = ["repeat", "sentence"] + num_prompts = 100 + prompts = [] + + random.seed(0) + random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) + + # Generate a mixed batch of prompts, some of which can be easily + # predicted by n-gram matching and some which likely cannot. + for kind in random_prompt_type_choices: + word_choices = ["test", "temp", "hello", "where"] + word = random.choice(word_choices) + if kind == "repeat": + prompt = f""" + please repeat the word '{word}' 10 times. + give no other output than the word at least ten times in a row, + in lowercase with spaces between each word and without quotes. + """ + elif kind == "sentence": + prompt = f""" + please give a ten-word sentence that + uses the word {word} at least once. + give no other output than that simple sentence without quotes. 
+ """ + else: + raise ValueError(f"Unknown prompt type: {kind}") + prompts.append([{"role": "user", "content": prompt}]) + + return prompts + + +@pytest.fixture +def sampling_config(): + return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) + + +@pytest.fixture +def model_name(): + return "LLM-Research/Meta-Llama-3.1-8B-Instruct" + + +def eagle_model_name(): + return "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B" + + +def eagle3_model_name(): + return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B" + + +def test_ngram_correctness( + monkeypatch: pytest.MonkeyPatch, + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, +): + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using ngram speculative decoding. + ''' + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + ref_llm = LLM(model=model_name, max_model_len=1024) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + + spec_llm = LLM( + model=model_name, + speculative_config={ + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": 3, + }, + max_model_len=1024, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 70% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. 
+ assert matches > int(0.7 * len(ref_outputs)) + del spec_llm + + +@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"]) +def test_eagle_correctness( + monkeypatch: pytest.MonkeyPatch, + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, + use_eagle3: bool, +): + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using eagle speculative decoding. + ''' + pytest.skip("Not current support for the test.") + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + ref_llm = LLM(model=model_name, max_model_len=2048) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + + spec_model_name = eagle3_model_name( + ) if use_eagle3 else eagle_model_name() + spec_llm = LLM( + model=model_name, + trust_remote_code=True, + speculative_config={ + "method": "eagle3" if use_eagle3 else "eagle", + "model": spec_model_name, + "num_speculative_tokens": 3, + "max_model_len": 2048, + }, + max_model_len=2048, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 66% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.66 * len(ref_outputs)) + del spec_llm diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 490c819..36ac972 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -110,6 +110,7 @@ class AscendMetadata: block_tables: torch.Tensor # (batch_size,). The sequence length per sequence. 
 Sequence length means # the computed tokens + new tokens None if it is a decoding. + query_start_loc: torch.Tensor query_lens: torch.Tensor seq_lens: torch.Tensor # Maximum query length in the batch. None for decoding. @@ -149,9 +150,13 @@ class AscendAttentionMetadataBuilder: self.runner.device, non_blocking=True) attn_mask = self.runner.attn_mask attn_state = self.runner.attn_state + query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] + query_start_loc = query_start_loc_cpu.to(self.runner.device, + non_blocking=True) attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens, block_tables=block_table, + query_start_loc=query_start_loc, query_lens=query_lens, seq_lens=seq_lens, max_query_len=max_query_len, diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 445a167..4604c88 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -158,4 +158,18 @@ # - https://github.com/vllm-project/vllm-ascend/pull/395 # Future Plan: # Revert it when the related pr is merged in vllm and vllm-ascend. -# \ No newline at end of file +# +# ** File: worker/patch_common/patch_eagle.py ** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.v1.spec_decode.eagle.prepare_inputs` +# Why: +# We need to use the patched `prepare_input_kernel` in `eagle.prepare_inputs`. +# The main reason to overwrite `prepare_input_kernel` is that it is a triton +# kernel, and Ascend does not currently support triton kernels. +# How: +# Re-implement the `prepare_input_kernel` triton kernel in pytorch. +# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit.... +# - Ascend doesn't support triton +# Future Plan: +# Revert it when Ascend supports triton kernels. 
+# diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index 9369596..5b55ac6 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -19,6 +19,7 @@ # patch files. import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa +import vllm_ascend.patch.worker.patch_common.patch_eagle # noqa import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_eagle.py b/vllm_ascend/patch/worker/patch_common/patch_eagle.py new file mode 100644 index 0000000..f1c2a62 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_eagle.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch +from vllm.v1.spec_decode.eagle import EagleProposer + + +def prepare_inputs( + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + + cu_num_tokens = torch.empty_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + cu_num_tokens[0] = 0 + + # FIXME(woosuk): Avoid 
 synchronization. + num_tokens = cu_num_tokens[-1].item() + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_num_tokens.device, + ) + + BLOCK_SIZE = 1024 + prepare_input_pytorch( + token_indices, + cu_target_query_lens, + cu_num_tokens, + block_size=BLOCK_SIZE, + ) + return cu_num_tokens, token_indices + + +def prepare_input_pytorch(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor, + cu_num_tokens: torch.Tensor, block_size: int): + num_pids = cu_num_tokens.shape[0] - 1 + + for pid in range(num_pids): + start_pos = cu_num_tokens[pid].item() + end_pos = cu_num_tokens[pid + 1].item() + num_tokens = end_pos - start_pos + + index_start = cu_query_lens[pid].item() + num_blocks = (num_tokens + block_size - 1) // block_size + + for i in range(num_blocks): + offset = i * block_size + torch.arange( + 0, block_size, + dtype=out_ptr.dtype, + device=cu_query_lens.device) + global_indices = start_pos + offset + values = index_start + offset + mask = offset < num_tokens + out_ptr[global_indices[mask]] = values[mask] + + +EagleProposer.prepare_inputs = prepare_inputs diff --git a/vllm_ascend/sample/__init__.py b/vllm_ascend/sample/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py new file mode 100644 index 0000000..384787b --- /dev/null +++ b/vllm_ascend/sample/rejection_sampler.py @@ -0,0 +1,456 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch +import torch.nn as nn +import vllm.v1.sample.rejection_sampler as rs +from vllm.logger import init_logger +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import (RejectionSampler, compute_probs, + generate_uniform_probs) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +logger = init_logger(__name__) + +PLACEHOLDER_TOKEN_ID = -1 +GREEDY_TEMPERATURE = -1 +# Maximum number of speculative draft tokens allowed per request in a single +# step. 
This value is chosen to be large enough to handle typical use cases. +MAX_SPEC_LEN = 32 + + +class AscendRejectionSampler(RejectionSampler, nn.Module): + """ + The implementation strictly follows the algorithm described in + https://arxiv.org/abs/2211.17192. + However, we want to clarify the terminology used in the implementation: + accepted tokens: tokens that are accepted based on the relationship + between the "raw" draft and target probabilities. + recovered tokens: tokens that are sampled based on the adjusted probability + distribution, which is derived from both the draft and target + probabilities. + bonus tokens: + If all proposed tokens are accepted, the bonus token is added to the + end of the sequence. The bonus token is only sampled from the target + probabilities. We pass in the bonus tokens instead of sampling them + in the rejection sampler to allow for more flexibility in the + sampling process. For example, we can use top_p, top_k sampling for + bonus tokens, while spec decode does not support these sampling + strategies. + output tokens: + Tokens are finally generated with the rejection sampler. + output tokens = accepted tokens + recovered tokens + bonus tokens + """ + + def forward( + self, + metadata: SpecDecodeMetadata, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_logits: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + ''' + Args: + metadata: + Metadata for spec decoding. + draft_probs (Optional[torch.Tensor]): + Probability distribution for the draft tokens. Shape is + [num_tokens, vocab_size]. Can be None if probabilities are + not provided, which is the case for ngram spec decode. + target_logits (torch.Tensor): + Target model's logits probability distribution. + Shape is [num_tokens, vocab_size]. 
Here, probabilities from + different requests are flattened into a single tensor because + this is the shape of the output logits. + NOTE: `target_logits` can be updated in place to save memory. + bonus_token_ids_tensor (torch.Tensor): + A tensor containing bonus tokens. Shape is [batch_size, 1]. + Bonus tokens are added to the end of the sequence if all + proposed tokens are accepted. We generate the bonus tokens + outside of the rejection sampler with the default sampling + strategy. It allows for more flexibility in the sampling + process such as top_p, top_k sampling. + sampling_metadata (SamplingMetadata): + Additional metadata needed for sampling, such as temperature, + top-k/top-p parameters, or other relevant information. + Returns: + output_token_ids (torch.Tensor): + A tensor containing the final output token IDs. + ''' + assert metadata.max_spec_len <= MAX_SPEC_LEN + # [num_tokens, vocab_size] + # NOTE(woosuk): `target_logits` can be updated in place inside the + # `compute_probs` function. 
+ target_probs = compute_probs( + target_logits, + metadata.cu_num_draft_tokens, + sampling_metadata, + ) + + output_token_ids = rejection_sample( + metadata.draft_token_ids, + metadata.num_draft_tokens, + metadata.max_spec_len, + metadata.cu_num_draft_tokens, + draft_probs, + target_probs, + bonus_token_ids, + sampling_metadata, + ) + return output_token_ids + + +def rejection_sample( + # [num_tokens] + draft_token_ids: torch.Tensor, + # [batch_size] + num_draft_tokens: list[int], + max_spec_len: int, + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + # [batch_size, 1] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + assert draft_token_ids.ndim == 1 + assert draft_probs is None or draft_probs.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 + assert target_probs.ndim == 2 + + batch_size = len(num_draft_tokens) + num_tokens = draft_token_ids.shape[0] + vocab_size = target_probs.shape[-1] + device = target_probs.device + assert draft_token_ids.is_contiguous() + assert draft_probs is None or draft_probs.is_contiguous() + assert target_probs.is_contiguous() + assert bonus_token_ids.is_contiguous() + assert target_probs.shape == (num_tokens, vocab_size) + + # Create output buffer. + output_token_ids = torch.empty( + (batch_size, max_spec_len + 1), + dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids. + device=device, + ) + output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) + + if sampling_metadata.all_greedy: + is_greedy = None + else: + is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE + if not sampling_metadata.all_random: + # Rejection sampling for greedy sampling requests. 
+ target_argmax = target_probs.argmax(dim=-1) + rejection_greedy_sample_pytorch( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + is_greedy, + max_spec_len, + # num_warps=1, + ) + if sampling_metadata.all_greedy: + return output_token_ids + + # Generate uniform probabilities for rejection sampling. + # [num_tokens] + uniform_probs = generate_uniform_probs( + num_tokens, + num_draft_tokens, + sampling_metadata.generators, + device, + ) + + # Sample recovered tokens for each position. + # [num_tokens] + recovered_token_ids = sample_recovered_tokens( + max_spec_len, + num_draft_tokens, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + sampling_metadata, + device, + ) + + # Rejection sampling for random sampling requests. + rejection_random_sample_pytorch( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + IS_NGRAM=draft_probs is None, + # num_warps=1, + ) + return output_token_ids + + +def expand_batch_to_tokens( + x: torch.Tensor, # [batch_size] + cu_num_tokens: torch.Tensor, # [batch_size] + num_tokens: int, + replace_from: int = 0, + replace_to: int = 0, +) -> torch.Tensor: + """Expand [batch_size] tensor to [num_tokens] tensor based on the number of + tokens per batch in cu_num_tokens. + + For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then + num_tokens = 6, and expanded_x = [a, a, b, b, b, c]. + + Args: + x: [batch_size] tensor to expand. + cu_num_tokens: [batch_size] tensor containing the cumulative number of + tokens per batch. Each element represents the total number of + tokens up to and including that batch. + num_tokens: Total number of tokens. + replace_from: int = 0 + Value to be replaced if it is found in x. + replace_to: int = 0 + Value to replace with when replace_from is found. 
+ Returns: + expanded_x: [num_tokens] tensor. + """ + batch_size = x.shape[0] + assert cu_num_tokens.shape[0] == batch_size + expanded_x = x.new_empty(num_tokens) + expand_pytorch( + expanded_x, + x, + cu_num_tokens, + replace_from, + replace_to, + MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. + ) + return expanded_x + + +def sample_recovered_tokens( + max_spec_len: int, + num_draft_tokens: list[int], + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens] + draft_token_ids: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] + target_probs: torch.Tensor, + sampling_metadata: SamplingMetadata, + device: torch.device, +) -> torch.Tensor: + # NOTE(woosuk): Create only one distribution for each request. + batch_size = len(num_draft_tokens) + vocab_size = target_probs.shape[-1] + q = torch.empty( + (batch_size, vocab_size), + dtype=torch.float32, + device=device, + ) + q.exponential_() + for i, generator in sampling_metadata.generators.items(): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. 
def rejection_greedy_sample_pytorch(
    output_token_ids,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens,  # [batch_size]
    draft_token_ids,  # [num_tokens]
    target_argmax,  # [num_tokens]
    bonus_token_ids,  # [batch_size]
    is_greedy=None,  # [batch_size] or None
    max_spec_len=None,
):
    """Greedy-path rejection sampling (PyTorch port of the Triton kernel).

    For every greedy request, each position emits the target model's argmax
    token; the first position where the draft token disagrees with the target
    argmax rejects all subsequent draft tokens. When every draft token is
    accepted, the bonus token is appended after the drafts. `output_token_ids`
    is updated in place; non-greedy rows are left untouched.
    """
    batch_size = output_token_ids.shape[0]

    if is_greedy is None:
        # Caller signalled that every request is greedy.
        is_greedy = torch.ones(batch_size,
                               dtype=torch.bool,
                               device=output_token_ids.device)

    for req_idx in range(batch_size):
        if not is_greedy[req_idx]:
            continue

        if req_idx == 0:
            start_idx = 0
        else:
            start_idx = cu_num_draft_tokens[req_idx - 1].item()
        end_idx = cu_num_draft_tokens[req_idx].item()
        num_draft_tokens = end_idx - start_idx

        rejected = False
        for pos in range(num_draft_tokens):
            draft_token_id = draft_token_ids[start_idx + pos].item()
            target_argmax_id = target_argmax[start_idx + pos].item()

            # The target argmax is always emitted at this position, even
            # when it rejects the draft token.
            output_token_ids[req_idx, pos] = target_argmax_id

            if draft_token_id != target_argmax_id:
                # First mismatch rejects everything after it; the original
                # kept looping with a no-op body here — break instead.
                rejected = True
                break

        if not rejected:
            bonus_token_id = bonus_token_ids[req_idx].item()
            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
def expand_pytorch(
    output_ptr,  # [num_tokens]
    input_ptr,  # [batch_size]
    cu_num_tokens_ptr,  # [batch_size]
    replace_from,
    replace_to,
    MAX_NUM_TOKENS,
):
    """Expand a per-request value into a flat per-token tensor, in place.

    Request `i`'s value `input_ptr[i]` is broadcast into
    `output_ptr[cu_num_tokens_ptr[i-1] : cu_num_tokens_ptr[i]]`; a value equal
    to `replace_from` is substituted with `replace_to` first. Each request's
    span must not exceed MAX_NUM_TOKENS (caller passes MAX_SPEC_LEN).
    """
    batch_size = len(input_ptr)
    # Loop-invariant: the original rebuilt this arange for every request.
    # NOTE(review): assumes cu_num_tokens_ptr is a tensor (it is at the
    # only call site, expand_batch_to_tokens) — confirm if reused elsewhere.
    offset = torch.arange(MAX_NUM_TOKENS, device=cu_num_tokens_ptr.device)

    for req_idx in range(batch_size):
        start_idx = 0 if req_idx == 0 else cu_num_tokens_ptr[req_idx - 1]
        end_idx = cu_num_tokens_ptr[req_idx]
        num_tokens = end_idx - start_idx

        src_val = input_ptr[req_idx]
        src_val = replace_to if src_val == replace_from else src_val

        # Only the first `num_tokens` offsets fall inside this request.
        mask = offset < num_tokens
        output_ptr[start_idx + offset[mask]] = src_val
len(cu_num_draft_tokens) + + for req_idx in range(batch_size): + start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1] + end_idx = cu_num_draft_tokens[req_idx] + num_draft_tokens = end_idx - start_idx + + for pos in range(num_draft_tokens): + token_idx = start_idx + pos + + if IS_NGRAM: + draft_token_id = draft_token_ids[token_idx] + orig_prob = target_probs[token_idx, draft_token_id] + target_probs[token_idx, draft_token_id] = 0 + prob = target_probs[token_idx].clone() + else: + draft_p = draft_probs[token_idx].clone() + target_p = target_probs[token_idx].clone() + prob = torch.maximum(target_p - draft_p, + torch.tensor(0.0, device=target_p.device)) + + q_values = torch.full((vocab_size, ), + float('-inf'), + device=q.device) + q_values[:vocab_size] = q[req_idx, :vocab_size] + + recovered_id = torch.argmax(prob / q_values).item() + output_token_ids[token_idx] = recovered_id + + if IS_NGRAM: + target_probs[token_idx, draft_token_id] = orig_prob + + +rs.expand_batch_to_tokens = expand_batch_to_tokens diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d910750..2ee7426 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -47,7 +47,12 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import 
LoRAModelRunnerMixin @@ -55,6 +60,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm_ascend.attention.attention import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.platform import NPUPlatform +from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.utils import vllm_version_is if TYPE_CHECKING: @@ -110,6 +116,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.model_config = vllm_config.model_config self.lora_config = vllm_config.lora_config self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled self.device = device @@ -202,6 +209,21 @@ class NPUModelRunner(LoRAModelRunnerMixin): # req_id -> (input_id -> encoder_output) self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} + # Set up speculative decoding. + self.use_spec_decode = False + if self.speculative_config: + self.use_spec_decode = True + if get_pp_group().is_last_rank: + if self.speculative_config.method == "ngram": + self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "eagle": + self.drafter = EagleProposer(self.vllm_config, + self.device) # type: ignore + else: + raise ValueError("Unknown speculative decoding method: " + f"{self.speculative_config.method}") + self.rejection_sampler = AscendRejectionSampler() + # Request states. self.requests: Dict[str, CachedRequestState] = {} # Persistent batch. 
@@ -511,7 +533,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata, + torch.Tensor, int, torch.Tensor]: # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -523,6 +546,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_input_tokens = self.vllm_config.pad_for_cudagraph( total_num_scheduled_tokens) else: + # Eager mode. num_input_tokens = total_num_scheduled_tokens modified_batch = self.attn_metadata_builder.reorder_batch( @@ -615,6 +639,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): common_prefix_len=None, **extra_builder_kwargs, ) + attn_metadata.num_input_tokens = num_input_tokens # Prepare input_ids token_indices = (positions_np + @@ -670,7 +695,106 @@ class NPUModelRunner(LoRAModelRunnerMixin): **model_kwargs, ) - return hidden_states[sample_indices] + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + spec_decode_metadata = None + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. 
+ num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in ( + scheduler_output.scheduled_spec_decode_tokens.items()): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens) + sample_indices = spec_decode_metadata.logits_indices + + return (attn_metadata, hidden_states, spec_decode_metadata, positions, + total_num_scheduled_tokens, sample_indices) + + def _calc_spec_decode_metadata( + self, + num_draft_tokens: np.ndarray, + cu_num_scheduled_tokens: np.ndarray, + ) -> SpecDecodeMetadata: + # Inputs: + # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] + # num_draft_tokens: [ 3, 0, 2, 0, 1] + # Outputs: + # cu_num_draft_tokens: [ 3, 3, 5, 5, 6] + # logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106, + # 206, 207, 208] + # target_logits_indices: [ 0, 1, 2, 5, 6, 9] + # bonus_logits_indices: [ 3, 4, 7, 8, 10] + + # Compute the logits indices. + # [4, 1, 3, 1, 2] + num_sampled_tokens = num_draft_tokens + 1 + # Step 1. [4, 5, 8, 9, 11] + cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) + total_num_sampled_tokens = cu_num_sampled_tokens[-1] + # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, + num_sampled_tokens) + # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets + # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + logits_indices = np.repeat( + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + logits_indices += arange + + # Compute the bonus logits indices. + bonus_logits_indices = cu_num_sampled_tokens - 1 + + # Compute the draft logits indices. 
+ # [3, 3, 5, 5, 6] + cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) + total_num_draft_tokens = cu_num_draft_tokens[-1] + # [0, 0, 0, 3, 3, 5] + cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens, + num_draft_tokens) + # [0, 1, 2, 0, 1, 0] + arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets + # [0, 0, 0, 5, 5, 9] + target_logits_indices = np.repeat( + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + # [0, 1, 2, 5, 6, 9] + target_logits_indices += arange + + # TODO: Optimize the CPU -> NPU copy. + cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( + self.device, non_blocking=True) + logits_indices = torch.from_numpy(logits_indices).to(self.device, + non_blocking=True) + target_logits_indices = torch.from_numpy(target_logits_indices).to( + self.device, non_blocking=True) + bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( + self.device, non_blocking=True) + + # Compute the draft token ids. + # draft_token_indices: [ 1, 2, 3, 105, 106, 208] + draft_token_ids = self.input_ids[logits_indices] + draft_token_ids = draft_token_ids[target_logits_indices + 1] + + metadata = SpecDecodeMetadata( + draft_token_ids=draft_token_ids, + num_draft_tokens=num_draft_tokens.tolist(), + cu_num_draft_tokens=cu_num_draft_tokens, + target_logits_indices=target_logits_indices, + bonus_logits_indices=bonus_logits_indices, + logits_indices=logits_indices, + ) + return metadata def apply_grammar_bitmask( self, @@ -726,6 +850,30 @@ class NPUModelRunner(LoRAModelRunnerMixin): ) return logits.to(self.device).to(logits_dtype) + def _get_spec_token_ids( + self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + scheduler_output: "SchedulerOutput", + spec_decode_metadata: SpecDecodeMetadata, + positions: torch.Tensor, + num_scheduled_tokens: int, + hidden_states: torch.Tensor, + attn_metadata: SpecDecodeMetadata, + ) -> Optional[list[list[int]]]: + if not 
self.use_spec_decode: + # Speculative decoding is not enabled. + spec_token_ids = None + elif self.speculative_config.method == "ngram": + assert isinstance(self.drafter, NgramProposer) + spec_token_ids = self._generate_draft_token_ids( + valid_sampled_token_ids, sampling_metadata) + elif self.speculative_config.method == "eagle": + raise NotImplementedError( + "eagle method for spec decode doesn't work on vllm-ascend currently" + ) + return spec_token_ids + @torch.inference_mode() def execute_model( self, @@ -736,9 +884,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): if not scheduler_output.total_num_scheduled_tokens: # Return empty ModelRunnerOuptut if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT - hidden_states = self._process_reqs(scheduler_output, - intermediate_tensors) - logits = self.model.compute_logits(hidden_states, None) + (attn_metadata, hidden_states, spec_decode_metadata, positions, + num_scheduled_tokens, + sample_indices) = (self._process_reqs(scheduler_output, + intermediate_tensors)) + logits = self.model.compute_logits(hidden_states[sample_indices], None) # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: @@ -746,10 +896,35 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata - sampler_output = self.sampler( - logits=logits, - sampling_metadata=sampling_metadata, - ) + if spec_decode_metadata is None: + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. 
+ bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. @@ -776,12 +951,29 @@ class NPUModelRunner(LoRAModelRunnerMixin): if max_gen_len == 1: # No spec decode tokens. valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + + spec_token_ids = self._get_spec_token_ids( + valid_sampled_token_ids, + sampling_metadata, + scheduler_output, + spec_decode_metadata, + positions, + num_scheduled_tokens, + hidden_states, + attn_metadata, + ) model_runner_output = ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=None, + spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict={}, ) @@ -968,6 +1160,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) + if hasattr(self, "drafter"): + logger.info("Loading drafter model...") + self.drafter.load_model(self.model) if self.lora_config: self.model = self.load_lora_model(self.model, self.model_config, @@ -1132,3 +1327,35 @@ class 
NPUModelRunner(LoRAModelRunnerMixin): # This usually takes 5~20 seconds. logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, npu_graph_size / (1 << 30)) + + def _generate_draft_token_ids( + self, + sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + ) -> list[list[int]]: + # TODO(woosuk): Optimize. + draft_token_ids: list[list[int]] = [] + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. + draft_token_ids.append([]) + continue + + # Skip requests that require top-p, top-k, etc. + req_id = self.input_batch.req_ids[i] + if not is_spec_decode_supported(req_id, self.input_batch): + draft_token_ids.append([]) + continue + + # Add sampled_token_ids to token_ids_cpu. + start_idx = self.input_batch.num_tokens_no_spec[i] + end_idx = start_idx + num_sampled_ids + self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids + drafter_output = self.drafter.propose( + self.input_batch.token_ids_cpu[i, :end_idx]) + if drafter_output is None or len(drafter_output) == 0: + draft_token_ids.append([]) + else: + draft_token_ids.append(drafter_output.tolist()) + return draft_token_ids