Clean up v0.9.1 code (#1672)

vLLM has released 0.9.2. This PR drops 0.9.1 support.

- vLLM version: v0.9.1
- vLLM main:
b942c094e3

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan
2025-07-09 08:52:24 +08:00
committed by GitHub
parent 0d4bc03946
commit 830332ebfc
23 changed files with 205 additions and 846 deletions


@@ -363,7 +363,6 @@ jobs:
           # To avoid oom, we need to run the test in a single process.
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
-          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
@@ -382,7 +381,6 @@ jobs:
           # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error.
           # To avoid oom, we need to run the test in a single process.
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
-          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
           pytest -sv tests/e2e/multicard/test_data_parallel.py
           pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \


@@ -20,11 +20,11 @@ In `vllm_ascend/patch`, you can see the code structure as follows:
 vllm_ascend
 ├── patch
 │    ├── platform
-│    │    ├── patch_0_9_1
+│    │    ├── patch_0_9_2
 │    │    ├── patch_common
 │    │    ├── patch_main
 │    ├── worker
-│    │    ├── patch_0_9_1
+│    │    ├── patch_0_9_2
 │    │    ├── patch_common
 │    │    ├── patch_main
 └───────────
@@ -38,15 +38,15 @@ vllm_ascend
 In both the **platform** and **worker** folders, there are several patch modules. They are used for patching different versions of vLLM.
-- `patch_0_9_1`: This module is used for patching vLLM 0.9.1. The version is always the nearest released version of vLLM. Once a new vLLM version is released, we will drop this patch module and bump to the new version. For example, `patch_0_9_2` is used for patching vLLM 0.9.2.
+- `patch_0_9_2`: This module is used for patching vLLM 0.9.2. The version is always the nearest released version of vLLM. Once a new vLLM version is released, we will drop this patch module and bump to the new version.
 - `patch_main`: This module is used for patching the code in the vLLM main branch.
-- `patch_common`: This module is used for patching both vLLM 0.9.1 and the vLLM main branch.
+- `patch_common`: This module is used for patching both vLLM 0.9.2 and the vLLM main branch.
 ## How to write a patch
 Before writing a patch, following the principle above, we should patch as little code as possible. If necessary, we can patch the code in either the **platform** or **worker** folder. Here is an example to patch the `distributed` module in vLLM.
-1. Decide which version of vLLM we should patch. For example, after analysis, here we want to patch both 0.9.1 and main of vLLM.
+1. Decide which version of vLLM we should patch. For example, after analysis, here we want to patch both 0.9.2 and main of vLLM.
 2. Decide which process we should patch. For example, here `distributed` belongs to the vLLM main process, so we should patch `platform`.
 3. Create the patch file in the right folder. The file should be named `patch_{module_name}.py`. The example here is `vllm_ascend/patch/platform/patch_common/patch_distributed.py`.
 4. Write your patch code in the new file. Here is an example:
@@ -79,4 +79,4 @@ Before writing a patch, following the principle above, we should patch the least
 ## Limitation
 1. In the V1 Engine, vLLM starts three kinds of processes: the main process, the EngineCore process, and the Worker process. Currently, vLLM Ascend only supports patching the code in the main process and the Worker process by default. If you want to patch code that runs in the EngineCore process, you should patch the EngineCore process entirely during setup; the entry code is in `vllm.v1.engine.core`. Please override `EngineCoreProc` and `DPEngineCoreProc` entirely.
-2. If you are running edited vLLM code, the vLLM version may be changed automatically. For example, if you run an edited vLLM based on v0.9.1, the vLLM version may change to v0.9.2xxx. In this case, the patch for v0.9.1 in vLLM Ascend would not work as expected, because vLLM Ascend can't distinguish the version of vLLM you're using. You can set the environment variable `VLLM_VERSION` to specify the version of vLLM you're using; then the patch for v0.9.1 should work.
+2. If you are running edited vLLM code, the vLLM version may be changed automatically. For example, if you run an edited vLLM based on v0.9.n, the vLLM version may change to v0.9.nxxx. In this case, the patch for v0.9.n in vLLM Ascend would not work as expected, because vLLM Ascend can't distinguish the version of vLLM you're using. You can set the environment variable `VLLM_VERSION` to specify the version of vLLM you're using; then the patch for v0.9.n should work.
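The version gate described above can be sketched as follows. This is a simplified illustration, not the real `vllm_ascend.utils.vllm_version_is` implementation; the installed-version constant is an assumption for this sketch:

```python
import os

# Hypothetical sketch of a version gate like vllm_ascend.utils.vllm_version_is.
# The installed-version lookup is stubbed with a constant for illustration.
INSTALLED_VLLM_VERSION = "0.9.2"  # assumption for this sketch

def vllm_version_is(target: str) -> bool:
    # VLLM_VERSION overrides detection, e.g. when running edited vLLM code
    # whose reported version no longer matches a released version.
    current = os.environ.get("VLLM_VERSION", INSTALLED_VLLM_VERSION)
    return current == target
```

With `VLLM_VERSION=0.9.2` exported, a locally edited vLLM that reports `0.9.2xxx` would still select the `patch_0_9_2` module.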


@@ -73,28 +73,6 @@ def test_models_distributed_DeepSeek_multistream_moe():
         vllm_model.generate_greedy(example_prompts, max_tokens)
-@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"})
-def test_models_distributed_topk() -> None:
-    example_prompts = [
-        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
-        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
-        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
-    ]
-    dtype = "half"
-    sampling_params = SamplingParams(max_tokens=5,
-                                     temperature=0.0,
-                                     top_k=50,
-                                     top_p=0.9)
-    with VllmRunner(
-            "deepseek-ai/DeepSeek-V2-Lite",
-            dtype=dtype,
-            tensor_parallel_size=4,
-            distributed_executor_backend="mp",
-    ) as vllm_model:
-        vllm_model.generate(example_prompts, sampling_params)
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
 def test_models_distributed_DeepSeek_dbo():
     example_prompts = ["The president of the United States is"] * 41


@@ -16,7 +16,6 @@ from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm_ascend.core.scheduler import AscendScheduler
-from vllm_ascend.utils import vllm_version_is
 EOS_TOKEN_ID = 50256
@@ -140,9 +139,7 @@ def create_requests(num_requests: int,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
-            **({
-                "pooling_params": None
-            } if not vllm_version_is("0.9.1") else {}),
+            pooling_params=None,
         )
         requests.append(request)
     return requests
@@ -201,10 +198,7 @@ def test_schedule(enable_prefix_caching: Optional[bool],
     # Test initial scheduling
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == len(requests)
-    if vllm_version_is("0.9.1"):
-        assert len(output.scheduled_cached_reqs) == 0
-    else:
-        assert output.scheduled_cached_reqs.num_reqs == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     # Verify all requests are scheduled.
     for req_id, num_tokens in output.num_scheduled_tokens.items():
@@ -241,10 +235,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 3
-    if vllm_version_is("0.9.1"):
-        assert len(output.scheduled_cached_reqs) == 0
-    else:
-        assert output.scheduled_cached_reqs.num_reqs == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     # The first request is scheduled partially - 400.
@@ -264,9 +255,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(output, model_runner_output)
     # Schedule the next step. All three requests are running.
@@ -274,10 +263,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     output1 = scheduler.schedule()
     assert len(scheduler.running) == 3
     assert len(output1.scheduled_new_reqs) == 0
-    if vllm_version_is("0.9.1"):
-        assert len(output1.scheduled_cached_reqs) == 3
-    else:
-        assert output1.scheduled_cached_reqs.num_reqs == 3
+    assert output1.scheduled_cached_reqs.num_reqs == 3
     assert len(output1.finished_req_ids) == 0
     assert output1.num_scheduled_tokens[requests[0].request_id] == 400
     assert output1.num_scheduled_tokens[requests[1].request_id] == 400
@@ -293,18 +279,13 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(output1, model_runner_output)
     output2 = scheduler.schedule()
     assert len(scheduler.running) == 3
     assert len(output2.scheduled_new_reqs) == 0
-    if vllm_version_is("0.9.1"):
-        assert len(output2.scheduled_cached_reqs) == 3
-    else:
-        assert output2.scheduled_cached_reqs.num_reqs == 3
+    assert output2.scheduled_cached_reqs.num_reqs == 3
     assert len(output2.finished_req_ids) == 0
     assert output2.num_scheduled_tokens[requests[0].request_id] == 1
     assert output2.num_scheduled_tokens[requests[1].request_id] == 1
@@ -351,9 +332,7 @@ def test_stop_via_update_from_output():
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(scheduler_output, model_output)
@@ -402,9 +381,7 @@ def test_stop_via_update_from_output():
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(scheduler_output, model_output)
@@ -452,9 +429,7 @@ def test_stop_via_update_from_output():
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(scheduler_output, model_output)
@@ -497,9 +472,7 @@ def test_stop_via_update_from_output():
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(scheduler_output, model_output)
@@ -549,9 +522,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(scheduler_output0, model_runner_output)
@@ -569,9 +540,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(scheduler_output1, model_runner_output)
@@ -622,9 +591,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
         spec_token_ids=spec_tokens,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     engine_core_outputs = scheduler.update_from_output(output,
                                                        model_runner_output)
@@ -657,16 +624,13 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
         else:
             assert req_id not in output.scheduled_spec_decode_tokens
-        model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids,
-            req_id_to_index=req_to_index,
-            sampled_token_ids=output_tokens,
-            spec_token_ids=None,
-            logprobs=None,
-            prompt_logprobs_dict={},
-            **({
-                "pooler_output": []
-            } if not vllm_version_is("0.9.1") else {}))
+        model_runner_output = ModelRunnerOutput(req_ids=req_ids,
+                                                req_id_to_index=req_to_index,
+                                                sampled_token_ids=output_tokens,
+                                                spec_token_ids=None,
+                                                logprobs=None,
+                                                prompt_logprobs_dict={},
+                                                pooler_output=[])
     engine_core_outputs = scheduler.update_from_output(output,
                                                        model_runner_output)
@@ -695,9 +659,7 @@ def make_output(scheduler: AscendScheduler):
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
 def assert_scheduler_empty(scheduler: AscendScheduler):


@@ -4,12 +4,12 @@ from typing import Any, Optional
 import pytest
 import torch
 import torch.nn.functional as F
+from vllm.v1.sample.logits_processor import LogitsProcessorManager
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm_ascend.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                                   AscendRejectionSampler)
-from vllm_ascend.utils import vllm_version_is
 DEVICE = "npu"
@@ -50,46 +50,23 @@ def create_sampling_metadata(
         temperature = None
     else:
         assert temperature is not None
-    if vllm_version_is("0.9.1"):
-        return SamplingMetadata(
-            temperature=temperature,
-            all_greedy=all_greedy,
-            all_random=not all_greedy,
-            top_p=top_p,
-            top_k=top_k,
-            min_p=torch.empty(1, ),
-            generators=generators,
-            max_num_logprobs=0,
-            no_penalties=False,
-            prompt_token_ids=None,
-            frequency_penalties=torch.tensor([]),
-            presence_penalties=torch.tensor([]),
-            repetition_penalties=torch.tensor([]),
-            output_token_ids=[],
-            min_tokens={},
-            logit_bias=[None],
-            allowed_token_ids_mask=None,
-            bad_words_token_ids={},
-        )
-    else:
-        from vllm.v1.sample.logits_processor import LogitsProcessorManager
-        return SamplingMetadata(temperature=temperature,
-                                all_greedy=all_greedy,
-                                all_random=not all_greedy,
-                                top_p=top_p,
-                                top_k=top_k,
-                                generators=generators,
-                                max_num_logprobs=0,
-                                no_penalties=False,
-                                prompt_token_ids=None,
-                                frequency_penalties=torch.tensor([]),
-                                presence_penalties=torch.tensor([]),
-                                repetition_penalties=torch.tensor([]),
-                                output_token_ids=[],
-                                allowed_token_ids_mask=None,
-                                bad_words_token_ids={},
-                                logitsprocs=LogitsProcessorManager())
+    return SamplingMetadata(temperature=temperature,
+                            all_greedy=all_greedy,
+                            all_random=not all_greedy,
+                            top_p=top_p,
+                            top_k=top_k,
+                            generators=generators,
+                            max_num_logprobs=0,
+                            no_penalties=False,
+                            prompt_token_ids=None,
+                            frequency_penalties=torch.tensor([]),
+                            presence_penalties=torch.tensor([]),
+                            repetition_penalties=torch.tensor([]),
+                            output_token_ids=[],
+                            allowed_token_ids_mask=None,
+                            bad_words_token_ids={},
+                            logitsprocs=LogitsProcessorManager())
 ########################### Tests for Greedy Sampling ###################


@@ -19,12 +19,10 @@
 from collections.abc import Sequence
 from typing import Optional
-import pytest
 from modelscope import snapshot_download  # type: ignore[import-untyped]
 from tests.conftest import HfRunner
 from tests.utils import check_embeddings_close, matryoshka_fy
-from vllm_ascend.utils import vllm_version_is
 def run_embedding_correctness_test(
@@ -51,8 +49,6 @@ def test_dummy():
     assert True
-@pytest.mark.skipif(vllm_version_is("0.9.1"),
-                    reason="vLLM 0.9.1 does not support embed task for v1")
 def test_embed_models_correctness(hf_runner, vllm_runner):
     queries = ['What is the capital of China?', 'Explain gravity']


@@ -21,12 +21,9 @@
 Run `pytest tests/test_offline_inference.py`.
 """
 import os
-from unittest.mock import patch
 import pytest
-import vllm  # noqa: F401
 from modelscope import snapshot_download  # type: ignore[import-untyped]
-from vllm import SamplingParams
 from vllm.assets.image import ImageAsset
 import vllm_ascend  # noqa: F401
@@ -106,24 +103,3 @@ def test_multimodal(model, prompt_template, vllm_runner):
     vllm_model.generate_greedy(prompts=prompts,
                                images=images,
                                max_tokens=64)
-@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"})
-def test_models_topk() -> None:
-    example_prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
-    sampling_params = SamplingParams(max_tokens=5,
-                                     temperature=0.0,
-                                     top_k=50,
-                                     top_p=0.9)
-    with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct",
-                    max_model_len=8192,
-                    dtype="float16",
-                    enforce_eager=True,
-                    gpu_memory_utilization=0.7) as vllm_model:
-        vllm_model.generate(example_prompts, sampling_params)


@@ -1,152 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# This file is a part of the vllm-ascend project.
-# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
-# Copyright 2023 The vLLM team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from typing import Optional
-import pytest
-import torch
-from vllm.v1.sample.sampler import Sampler  # noqa: F401
-from vllm_ascend.utils import vllm_version_is
-# Set tolerance to 1 for quant ops
-DEFAULT_ATOL = 1e-3
-DEFAULT_RTOL = 1e-3
-def apply_min_p_new(
-    logits: torch.Tensor,
-    min_p: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Filters logits using adaptive probability thresholding.
-    """
-    if min_p == 0:
-        return logits
-    # Convert logits to probability distribution
-    probability_values = torch.nn.functional.softmax(logits, dim=-1)
-    # Calculate maximum probabilities per sequence
-    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
-    # Reshape min_p for broadcasting
-    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
-    # Identify valid tokens using threshold comparison
-    # Apply mask using boolean indexing
-    logits = logits.masked_fill(probability_values < adjusted_min_p,
-                                -float('inf'))
-    return logits
-def apply_top_k_top_p(
-    logits: torch.Tensor,
-    k: Optional[torch.Tensor],
-    p: Optional[torch.Tensor],
-) -> torch.Tensor:
-    """Apply top-k and top-p masks to the logits.
-    If a top-p is used, this function will sort the logits tensor,
-    which can be slow for large batches.
-    The logits tensor may be updated in-place.
-    """
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-    if k is not None:
-        # Apply top-k.
-        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
-        # Get all the top_k values.
-        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
-        top_k_mask = logits_sort < top_k_mask
-        logits_sort.masked_fill_(top_k_mask, -float("inf"))
-    if p is not None:
-        # Apply top-p.
-        probs_sort = logits_sort.softmax(dim=-1)
-        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
-        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
-        # at least one
-        top_p_mask[:, -1] = False
-        logits_sort.masked_fill_(top_p_mask, -float("inf"))
-    # Re-sort the probabilities.
-    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
-    return logits
-def apply_top_k_top_p_new(
-    logits: torch.Tensor,
-    k: Optional[torch.Tensor],
-    p: Optional[torch.Tensor],
-) -> torch.Tensor:
-    batch_size, vocab_size = logits.shape
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-    # Apply top-k.
-    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
-    top_k_mask = logits_sort < boundary
-    logits_sort.masked_fill_(top_k_mask, -float("inf"))
-    if p is not None:
-        # Apply top-p.
-        cutoff = top_k_mask.sum(dim=-1).min()
-        probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
-        probs_sum = probs_sort.cumsum(dim=-1)
-        top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
-        top_p_mask[:, -1] = True
-        strides = torch.arange(0,
-                               batch_size * vocab_size,
-                               vocab_size,
-                               device=logits.device)
-        flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
-        valid_idx = torch.masked_select(flatten_idx, top_p_mask)
-        logits_flatten = logits.flatten()
-        valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
-        logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
-        logits[valid_idx] = valid_logits
-    return logits.reshape(batch_size, vocab_size)
-# test with leading dimension and merge seqlen and batch_size as num_tokens
-@pytest.mark.skipif(not vllm_version_is("0.9.1"),
-                    reason="apply_min_p has been removed after vllm 0.9.1")
-@torch.inference_mode()
-def test_apply_min_p() -> None:
-    logits = torch.randn((128, 7168)).npu()
-    min_p = torch.Tensor([0.01]).npu()
-    logits_new = apply_min_p_new(logits, min_p)
-    sampler = Sampler()
-    logits_old = sampler.apply_min_p(logits, min_p)
-    # Compare the results.
-    torch.testing.assert_close(logits_new,
-                               logits_old,
-                               atol=DEFAULT_ATOL,
-                               rtol=DEFAULT_RTOL)
-# test with leading dimension and merge seqlen and batch_size as num_tokens
-@torch.inference_mode()
-def test_apply_top_k_top_p() -> None:
-    logits = torch.randn((128, 7168)).npu()
-    k = torch.Tensor([-1]).int().npu()
-    p = torch.Tensor([1]).int().npu()
-    logits_new = apply_top_k_top_p_new(logits, k, p)
-    logits_old = apply_top_k_top_p(logits, k, p)
-    # Compare the results.
-    torch.testing.assert_close(logits_new,
-                               logits_old,
-                               atol=DEFAULT_ATOL,
-                               rtol=DEFAULT_RTOL)
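For reference, the min-p filtering implemented by the deleted `apply_min_p_new` above can be expressed in plain Python. This is a framework-free sketch of the same thresholding, not the torch/NPU implementation:

```python
import math

def apply_min_p(logits, min_p):
    """Mask tokens whose probability falls below min_p times the top probability."""
    if min_p == 0:
        return list(logits)
    # Softmax over the vocabulary.
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Adaptive cutoff: a fraction of the most likely token's probability.
    threshold = min_p * max(probs)
    return [x if p >= threshold else float("-inf")
            for x, p in zip(logits, probs)]
```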


@@ -31,7 +31,6 @@ from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm_ascend.core.scheduler import AscendScheduler
-from vllm_ascend.utils import vllm_version_is
 EOS_TOKEN_ID = 50256
@@ -131,9 +130,7 @@ def create_requests(num_requests: int,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
-            **({
-                "pooling_params": None
-            } if not vllm_version_is("0.9.1") else {}),
+            pooling_params=None,
         )
         requests.append(request)
     return requests
@@ -192,10 +189,7 @@ def test_schedule(enable_prefix_caching: Optional[bool],
     # Test initial scheduling
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == len(requests)
-    if vllm_version_is("0.9.1"):
-        assert len(output.scheduled_cached_reqs) == 0
-    else:
-        assert output.scheduled_cached_reqs.num_reqs == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     # Verify all requests are scheduled.
     for req_id, num_tokens in output.num_scheduled_tokens.items():
@@ -245,9 +239,7 @@ def test_stop_via_update_from_output():
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(scheduler_output, model_output)
@@ -294,9 +286,7 @@ def test_stop_via_update_from_output():
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(scheduler_output, model_output)
@@ -342,9 +332,7 @@ def test_stop_via_update_from_output():
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(scheduler_output, model_output)
@@ -386,9 +374,7 @@ def test_stop_via_update_from_output():
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
-        **({
-            "pooler_output": []
-        } if not vllm_version_is("0.9.1") else {}))
+        pooler_output=[])
     scheduler.update_from_output(scheduler_output, model_output)


@@ -1,31 +0,0 @@
-import importlib
-import os
-from unittest import mock
-import torch
-from vllm.v1.sample.ops import topk_topp_sampler
-from tests.ut.base import TestBase
-class TestTopKTopPSamplerOptimize(TestBase):
-    @mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"})
-    @mock.patch("torch_npu.npu_top_k_top_p")
-    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
-        # We have to patch and reload because the patch will take effect
-        # only after VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE is set.
-        import vllm_ascend.patch.worker.patch_0_9_1.patch_sampler
-        importlib.reload(vllm_ascend.patch.worker.patch_0_9_1.patch_sampler)
-        mock_npu_op.return_value = (torch.randn(1, 3))
-        sampler = topk_topp_sampler.TopKTopPSampler()
-        logits = torch.tensor([[1.0, 2.0, 3.0]])
-        k = torch.tensor([2])
-        p = torch.tensor([0.9])
-        generators = {0: torch.Generator()}
-        generators[0].manual_seed(42)
-        sampler.forward_native(logits, generators, k, p)
-        mock_npu_op.assert_called_once_with(logits, p, k)
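The reload dance in the deleted test above exists because the monkey patch is applied at import time, gated on an environment variable. A minimal stand-alone sketch of that pattern (all names here are hypothetical stand-ins, not vLLM Ascend APIs):

```python
import os
from unittest import mock

def native_sampler(logits):
    return "native"

def npu_sampler(logits):
    return "npu-optimized"

class SamplerRegistry:
    # The active implementation, as a patch module would leave it after import.
    forward = staticmethod(native_sampler)

def apply_patch():
    # Mirrors import-time patching: the choice is made once, when the patch
    # module is (re)loaded, based on the environment variable at that moment.
    if os.environ.get("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE") == "1":
        SamplerRegistry.forward = staticmethod(npu_sampler)

with mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1"}):
    apply_patch()  # corresponds to importlib.reload(...) in the real test
result = SamplerRegistry.forward([1.0, 2.0, 3.0])
```

Setting the variable after import would have no effect, which is why the real test had to reload the patch module under `mock.patch.dict`.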


@@ -32,8 +32,6 @@ from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 
-from vllm_ascend.utils import vllm_version_is
-
 
 class AscendScheduler(Scheduler):
     """This Scheduler extends vllm's original v1 scheduler
@@ -366,32 +364,12 @@ class AscendScheduler(Scheduler):
                 req_to_new_block_ids[req.request_id])
             for req in scheduled_new_reqs
         ]
-        if vllm_version_is("0.9.1"):
-            resumed_reqs_data = [
-                self._make_cached_request_data(
-                    req,
-                    num_scheduled_tokens[req.request_id],
-                    len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                    req_to_new_block_ids[req.request_id],
-                    resumed_from_preemption=True,
-                ) for req in scheduled_resumed_reqs
-            ]
-            running_reqs_data = [
-                self._make_cached_request_data(
-                    req,
-                    num_scheduled_tokens[req.request_id],
-                    len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                    req_to_new_block_ids[req.request_id],
-                    resumed_from_preemption=False,
-                ) for req in scheduled_running_reqs
-            ]
-            scheduled_cached_reqs = resumed_reqs_data + running_reqs_data
-        else:
-            cached_reqs_data = self._make_cached_request_data(
-                scheduled_running_reqs, scheduled_resumed_reqs,
-                num_scheduled_tokens, scheduled_spec_decode_tokens,
-                req_to_new_block_ids)
-            scheduled_cached_reqs = cached_reqs_data
+        cached_reqs_data = self._make_cached_request_data(
+            scheduled_running_reqs, scheduled_resumed_reqs,
+            num_scheduled_tokens, scheduled_spec_decode_tokens,
+            req_to_new_block_ids)
+        scheduled_cached_reqs = cached_reqs_data
 
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,

View File

@@ -50,10 +50,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
     # value is None, which means the system default C compiler will be used.
     "C_COMPILER":
     lambda: os.getenv("C_COMPILER", None),
-    # Whether to enable the topk optimization. It's disabled by default for experimental support
-    # We'll make it enabled by default in the future.
-    "VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE":
-    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE", '0'))),
     # The version of the Ascend chip. If not set, the default value is
     # ASCEND910B1. It's used for package building. Please make sure that the
     # version is correct.
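The removed entry followed this env table's `bool(int(os.getenv(...)))` convention. Note that the convention only accepts integer strings; a value such as `"true"` raises `ValueError`. A small sketch of the conversion chain (the variable name is illustrative):

```python
import os


def env_flag(name: str, default: str = "0") -> bool:
    # The conversion chain used throughout the env table:
    # string -> int -> bool, so only integer strings like "0"/"1" are valid.
    return bool(int(os.getenv(name, default)))


os.environ["DEMO_FLAG"] = "1"  # DEMO_FLAG is an illustrative name
enabled = env_flag("DEMO_FLAG")

os.environ["DEMO_FLAG"] = "true"  # non-integer values are rejected
try:
    env_flag("DEMO_FLAG")
    value_error = False
except ValueError:
    value_error = True
```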

View File

@@ -78,7 +78,7 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig,
                                               make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.utils import dispose_tensor, vllm_version_is
+from vllm_ascend.utils import dispose_tensor
 
 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
@@ -1032,19 +1032,12 @@ class CustomDeepseekDBOForCausalLM(DeepseekV2ForCausalLM):
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                if vllm_version_is("0.9.1"):
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                else:
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id,
-                                  return_success=False)
+                weight_loader(param,
+                              loaded_weight,
+                              name,
+                              shard_id=shard_id,
+                              expert_id=expert_id,
+                              return_success=False)
                 break
             else:
                 # Skip loading extra bias for GPTQ models.

View File

@@ -75,7 +75,7 @@ from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor, vllm_version_is)
+                               npu_wait_tensor)
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -936,19 +936,12 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                if vllm_version_is("0.9.1"):
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                else:
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id,
-                                  return_success=False)
+                weight_loader(param,
+                              loaded_weight,
+                              name,
+                              shard_id=shard_id,
+                              expert_id=expert_id,
+                              return_success=False)
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
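Both DeepSeek loaders drop the same gate: with 0.9.1 support removed, the expert `weight_loader` can always be called with `return_success`. One alternative the PR did not need, shown here purely as an illustrative sketch with stand-in loader functions, is signature probing instead of a version check:

```python
import inspect


def load_v091(param, weight, name, *, shard_id, expert_id):
    # Stand-in for the 0.9.1-style expert weight loader (no return_success).
    return (name, shard_id, expert_id)


def load_v092(param, weight, name, *, shard_id, expert_id, return_success):
    # Stand-in for the 0.9.2-style loader, which accepts return_success.
    return (name, shard_id, expert_id, return_success)


def call_loader(loader, param, weight, name, shard_id, expert_id):
    # Pass the new keyword only when the loader's signature declares it,
    # so the call site works against either signature without a version flag.
    kwargs = {"shard_id": shard_id, "expert_id": expert_id}
    if "return_success" in inspect.signature(loader).parameters:
        kwargs["return_success"] = False
    return loader(param, weight, name, **kwargs)
```

Dropping the old branch outright, as this PR does, is the simpler choice once only one signature remains in scope.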

View File

@@ -28,6 +28,10 @@ from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
                               tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group, get_tp_group
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.fused_moe.config import \
+    FusedMoEConfig  # isort: skip
+from vllm.model_executor.layers.fused_moe.config import \
+    FusedMoEParallelConfig  # isort: skip
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
 from vllm.model_executor.layers.quantization.base_config import \
@@ -39,16 +43,7 @@ from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
                                get_fused_moe_state, is_310p, npu_stream_switch,
-                               npu_wait_tensor, vllm_version_is)
-
-if vllm_version_is("0.9.1"):
-    from vllm.model_executor.layers.fused_moe.layer import \
-        FusedMoEParallelConfig
-    from vllm.model_executor.layers.fused_moe.layer import \
-        MoEConfig as FusedMoEConfig
-else:
-    from vllm.model_executor.layers.fused_moe.config import (
-        FusedMoEConfig, FusedMoEParallelConfig)
+                               npu_wait_tensor)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -1177,27 +1172,15 @@ class AscendFusedMoE(FusedMoE):
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-
-        if vllm_version_is("0.9.1"):
-            moe = FusedMoEConfig(
-                num_experts=self.global_num_experts,
-                experts_per_token=top_k,
-                hidden_dim=hidden_size,
-                num_local_experts=self.local_num_experts,
-                moe_parallel_config=self.moe_parallel_config,
-                # TODO (bnell): this needs to be fixed for quantized types.
-                in_dtype=params_dtype,
-            )
-        else:
-            moe = FusedMoEConfig.make(
-                num_experts=self.global_num_experts,
-                experts_per_token=top_k,
-                hidden_dim=hidden_size,
-                num_local_experts=self.local_num_experts,
-                moe_parallel_config=self.moe_parallel_config,
-                # TODO (bnell): this needs to be fixed for quantized types.
-                in_dtype=params_dtype,
-                quant_config=quant_config)
+        moe = FusedMoEConfig.make(
+            num_experts=self.global_num_experts,
+            experts_per_token=top_k,
+            hidden_dim=hidden_size,
+            num_local_experts=self.local_num_experts,
+            moe_parallel_config=self.moe_parallel_config,
+            # TODO (bnell): this needs to be fixed for quantized types.
+            in_dtype=params_dtype,
+            quant_config=quant_config)
 
         if quant_config is None:
             self.quant_method = AscendUnquantizedFusedMoEMethod(moe)

View File

@@ -24,9 +24,9 @@
 #    each worker's `__init__` function.
 #
 # Then in each kind of patch, there are three folders:
-# - patch_0_9_1: contains the patches applied when vllm version is 0.9.1.
+# - patch_0_9_2: contains the patches applied when vllm version is 0.9.2.
 # - patch_main: contains the patches applied when vllm version is main branch.
-# - patch_common: contains the patches applied in both 0.9.1 and main branch.
+# - patch_common: contains the patches applied in both 0.9.2 and main branch.
 #
 # Once a new patch is added in vllm-ascend, please add the patch description into this file as well.
 # ----------------------------------------------------------------------------------
@@ -105,32 +105,6 @@
 #   Future Plan:
 #       Revert it when the related pr is merged in vllm and vllm-ascend.
 #
-# ** File: worker/patch_common/patch_sampler.py **
-#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-#   1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
-#    Why:
-#       We need to use the patched `apply_top_k_top_p` in `sample`.
-#       The mainly reason to overwrite `apply_top_k_top_p` is
-#       to improve performance.
-#    How
-#       Re-implementation the `apply_top_k_top_p` function by pytorch
-#    Related PR (if no, explain why):
-#       - https://github.com/vllm-project/vllm-ascend/pull/970
-#    Future Plan:
-#       Revert it when the ascend scatter performance improves.
-#
-#   2. `vllm.v1.sample.sampler.Sampler.apply_min_p`
-#    Why:
-#       We need to use the patched `apply_min_p` in `sample`.
-#       The mainly reason to overwrite `apply_min_p` is
-#       to improve performance.
-#    How
-#       Re-implementation the `apply_min_p` function by pytorch
-#    Related PR (if no, explain why):
-#       - https://github.com/vllm-project/vllm-ascend/pull/970
-#    Future Plan:
-#       Revert it when the ascend indexput performance improves.
-#
 # ** File: worker/patch_common/patch_distributed.py **
 #    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # 1. `vllm.distributed.parallel_state.GroupCoordinator`
@@ -154,4 +128,4 @@
 #   Related PR (if no, explain why):
 #       This is the problem in vllm-ascend
 #   Future Plan:
-#       Remove this patch once pytorch 2.7.0 is supported for vllm ascend.
+#       Remove this patch once pytorch 2.7.0 is supported for vllm ascend.

View File

@@ -17,8 +17,8 @@
 from vllm_ascend.utils import vllm_version_is
 
 # Import specific patches for different versions
-if vllm_version_is("0.9.1"):
-    from vllm_ascend.patch.platform import patch_0_9_1  # noqa: F401
+if vllm_version_is("0.9.2"):
+    from vllm_ascend.patch.platform import patch_0_9_2  # noqa: F401
     from vllm_ascend.patch.platform import patch_common  # noqa: F401
 else:
     from vllm_ascend.patch.platform import patch_common  # noqa: F401

View File

@@ -18,8 +18,8 @@
 from vllm_ascend.utils import vllm_version_is
 
 # Import specific patches for different versions
-if vllm_version_is("0.9.1"):
-    from vllm_ascend.patch.worker import patch_0_9_1  # noqa: F401
+if vllm_version_is("0.9.2"):
+    from vllm_ascend.patch.worker import patch_0_9_2  # noqa: F401
     from vllm_ascend.patch.worker import patch_common  # noqa: F401
 else:
     from vllm_ascend.patch.worker import patch_common  # noqa: F401
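Both `__init__` files dispatch on `vllm_version_is`. The shape of that dispatch, with a stand-in version check (the real helper lives in `vllm_ascend.utils` and compares against the installed vllm package; the hard-coded `"0.9.2"` below is purely for illustration):

```python
# Stand-in for the installed vllm version; illustrative only.
INSTALLED_VLLM = "0.9.2"


def vllm_version_is(version: str) -> bool:
    # Sketch of vllm_ascend.utils.vllm_version_is: exact-match comparison.
    return INSTALLED_VLLM == version


def load_patches() -> list:
    # Mirrors the import structure above: version-specific patches first,
    # then the patches shared with the main branch.
    patches = []
    if vllm_version_is("0.9.2"):
        patches.append("patch_0_9_2")
    patches.append("patch_common")
    return patches
```

Because `patch_common` is imported on both branches, the `else` arm exists only so main-branch installs skip the version-pinned patches.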

View File

@@ -1,106 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# SPDX-License-Identifier: Apache-2.0
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from typing import Optional
-
-import torch
-import torch_npu
-from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
-from vllm.v1.sample.sampler import Sampler
-
-from vllm_ascend import envs
-
-
-def apply_min_p(
-    self,
-    logits: torch.Tensor,
-    min_p: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Filters logits using adaptive probability thresholding.
-    """
-    # Convert logits to probability distribution
-    probability_values = torch.nn.functional.softmax(logits, dim=-1)
-    # Calculate maximum probabilities per sequence
-    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
-    # Reshape min_p for broadcasting
-    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
-    # Identify valid tokens using threshold comparison
-    # Apply mask using boolean indexing
-    logits = logits.masked_fill(probability_values < adjusted_min_p,
-                                -float('inf'))
-    return logits
-
-
-def _apply_top_k_top_p(
-    logits: torch.Tensor,
-    k: torch.Tensor,
-    p: torch.Tensor,
-) -> torch.Tensor:
-    if p is not None and k is not None:
-        # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
-        return torch_npu.npu_top_k_top_p(logits, p, k)
-
-    probs = logits.softmax(dim=-1)
-    probs_sort, _ = probs.sort(dim=-1, descending=False)
-
-    if k is not None:
-        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
-        top_k_count = top_k_count.unsqueeze(dim=1)
-        top_k_cutoff = probs_sort.gather(-1, top_k_count)
-
-        # Make sure the no top-k rows are no-op.
-        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
-        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
-
-        elements_to_discard = probs < top_k_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    if p is not None:
-        cumprob = torch.cumsum(probs_sort, dim=-1)
-        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
-        top_p_mask[:, -1] = False  # at least one
-
-        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
-        top_p_cutoff = probs_sort.gather(-1, top_p_count)
-        elements_to_discard = probs < top_p_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    return logits
-
-
-def topk_topp_forward_native(
-    self,
-    logits: torch.Tensor,
-    generators: dict[int, torch.Generator],
-    k: Optional[torch.Tensor],
-    p: Optional[torch.Tensor],
-) -> torch.Tensor:
-    """
-    PyTorch-native implementation of top-k and top-p sampling.
-
-    The logits tensor may be updated in-place.
-    """
-    logits = _apply_top_k_top_p(logits, k, p)
-    probs = logits.softmax(dim=-1, dtype=torch.float32)
-    return random_sample(probs, generators)
-
-
-Sampler.apply_min_p = apply_min_p
-if envs.VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE:
-    TopKTopPSampler.forward_native = topk_topp_forward_native
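The fallback path in the removed `_apply_top_k_top_p` sorts probabilities and cuts at the top-k and cumulative top-p thresholds. The same idea, in plain Python for a single row of logits (no torch; a purely illustrative nucleus-style variant, not the exact masking arithmetic deleted above):

```python
import math


def top_k_top_p_filter(logits, k, p):
    # Softmax over one row of logits (with max-subtraction for stability).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the indices of the k most probable tokens...
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order[:k])

    # ...then walk the sorted tokens and drop the rest once cumulative
    # probability reaches p, always keeping at least the most probable token.
    cum = 0.0
    for rank, i in enumerate(order):
        if rank > 0 and cum >= p:
            keep.discard(i)
        cum += probs[i]

    # Masked-out positions get -inf, matching the torch implementation.
    return [x if i in keep else float("-inf") for i, x in enumerate(logits)]
```

The NPU fast path replaced all of this with a single fused `torch_npu.npu_top_k_top_p(logits, p, k)` call; with the move of the sampler patch out of the 0.9.1 tree, both paths now live elsewhere in vllm-ascend.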

View File

@@ -14,4 +14,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import vllm_ascend.patch.worker.patch_0_9_1.patch_sampler  # noqa

View File

@@ -44,6 +44,7 @@ from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models.interfaces import has_step_pooler
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
@@ -79,7 +80,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration,
                                check_torchair_cache_exist, is_310p,
                                maybe_converting_weight_acl_format,
-                               vllm_version_is, write_kv_cache_bytes_to_file)
+                               write_kv_cache_bytes_to_file)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -95,9 +96,6 @@ import vllm.envs as envs_vllm
 import vllm_ascend.envs as envs_ascend
 
-if vllm_version_is("0.9.1"):
-    from vllm.v1.spec_decode.utils import is_spec_decode_supported
-
 if is_310p():
     torch_npu.npu.set_compile_mode(jit_compile=False)
@@ -408,16 +406,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         else:
             generator = None
 
-        # For vllm v0.9.1 version compatibility, we check if
-        # `pooling_params` is present in the new request data.
-        pooling_params = getattr(new_req_data, "pooling_params", None)
-
         self.requests[req_id] = CachedRequestState(
             req_id=req_id,
             prompt_token_ids=new_req_data.prompt_token_ids,
             mm_inputs=new_req_data.mm_inputs,
             mm_positions=new_req_data.mm_positions,
             sampling_params=sampling_params,
-            pooling_params=pooling_params,
+            pooling_params=new_req_data.pooling_params,
             generator=generator,
             block_ids=new_req_data.block_ids,
             num_computed_tokens=new_req_data.num_computed_tokens,
@@ -465,62 +460,59 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             req_ids_to_add.append(req_id)
 
         # Update the states of the running/resumed requests.
-        if vllm_version_is("0.9.1"):
-            for req_data in scheduler_output.scheduled_cached_reqs:
-                req_id = req_data.req_id
-                req_state = self.requests[req_id]
-
-                # Update the cached states.
-                num_computed_tokens = req_data.num_computed_tokens
-                req_state.num_computed_tokens = num_computed_tokens
-                # Add the sampled token(s) from the previous step (if any).
-                # This doesn't include "unverified" tokens like spec decode tokens.
-                num_new_tokens = (num_computed_tokens +
-                                  len(req_data.new_token_ids) -
-                                  req_state.num_tokens)
-                if num_new_tokens == 1:
-                    # Avoid slicing list in most common case.
-                    req_state.output_token_ids.append(
-                        req_data.new_token_ids[-1])
-                elif num_new_tokens > 0:
-                    req_state.output_token_ids.extend(
-                        req_data.new_token_ids[-num_new_tokens:])
-                # Update the block IDs.
-                if not req_data.resumed_from_preemption:
-                    # Append the new blocks to the existing block IDs.
-                    for block_ids, new_block_ids in zip(  # type: ignore[call-overload]
-                            req_state.block_ids,
-                            req_data.new_block_ids,
-                            strict=True):
-                        block_ids.extend(new_block_ids)
-                else:
-                    # The request is resumed from preemption.
-                    # Replace the existing block IDs with the new ones.
-                    req_state.block_ids = req_data.new_block_ids
-
-                req_index = self.input_batch.req_id_to_index.get(req_id)
-                if req_index is None:
-                    # The request is not in the persistent batch.
-                    # The request was either preempted and resumed later, or was not
-                    # scheduled in the previous step and needs to be added again.
-                    req_ids_to_add.append(req_id)
-                    continue
-
-                # Update the persistent batch.
-                self.input_batch.num_computed_tokens_cpu[req_index] = (
-                    num_computed_tokens)
-
-                start_index = (len(req_state.block_ids) -
-                               len(req_data.new_block_ids))
-                self.input_batch.block_table.append_row(
-                    req_data.new_block_ids, req_index)
-
-                # Add new_token_ids to token_ids_cpu.
-                start_token_index = num_computed_tokens
-                end_token_index = num_computed_tokens + len(
-                    req_data.new_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index,
-                    start_token_index:end_token_index] = req_data.new_token_ids
-                self.input_batch.num_tokens_no_spec[
-                    req_index] = end_token_index
-                # Add spec_token_ids to token_ids_cpu.
+        req_data = scheduler_output.scheduled_cached_reqs
+        is_last_rank = get_pp_group().is_last_rank
+        for i, req_id in enumerate(req_data.req_ids):
+            req_state = self.requests[req_id]
+            num_computed_tokens = req_data.num_computed_tokens[i]
+            new_block_ids = req_data.new_block_ids[i]
+            resumed_from_preemption = req_data.resumed_from_preemption[i]
+
+            req_state.num_computed_tokens = num_computed_tokens
+            if not is_last_rank:
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec decode tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    # Avoid slicing list in most common case.
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+            # Update the block IDs.
+            if not resumed_from_preemption:
+                # Append the new blocks to the existing block IDs.
+                for block_ids, new_ids in zip(  # type: ignore[call-overload]
+                        req_state.block_ids, new_block_ids):
+                    block_ids.extend(new_ids)
+            else:
+                # The request is resumed from preemption.
+                # Replace the existing block IDs with the new ones.
+                req_state.block_ids = new_block_ids
+
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+            if req_index is None:
+                # The request is not in the persistent batch.
+                # The request was either preempted and resumed later, or was not
+                # scheduled in the previous step and needs to be added again.
+                req_ids_to_add.append(req_id)
+                continue
+
+            # Update the persistent batch.
+            self.input_batch.num_computed_tokens_cpu[req_index] = (
+                num_computed_tokens)
+            self.input_batch.block_table.append_row(new_block_ids, req_index)
+
+            if not is_last_rank:
+                # Add new_token_ids to token_ids_cpu.
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                # Add spec_token_ids to token_ids_cpu.
@@ -534,75 +526,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     start_index:end_token_index] = spec_token_ids
                 # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
                 self.input_batch.num_tokens[req_index] = end_token_index
-        else:
-            req_data = scheduler_output.scheduled_cached_reqs
-            is_last_rank = get_pp_group().is_last_rank
-            for i, req_id in enumerate(req_data.req_ids):
-                req_state = self.requests[req_id]
-                num_computed_tokens = req_data.num_computed_tokens[i]
-                new_block_ids = req_data.new_block_ids[i]
-                resumed_from_preemption = req_data.resumed_from_preemption[i]
-
-                req_state.num_computed_tokens = num_computed_tokens
-                if not is_last_rank:
-                    new_token_ids = req_data.new_token_ids[i]
-                    # Add the sampled token(s) from the previous step (if any).
-                    # This doesn't include "unverified" tokens like spec decode tokens.
-                    num_new_tokens = (num_computed_tokens +
-                                      len(new_token_ids) -
-                                      req_state.num_tokens)
-                    if num_new_tokens == 1:
-                        # Avoid slicing list in most common case.
-                        req_state.output_token_ids.append(new_token_ids[-1])
-                    elif num_new_tokens > 0:
-                        req_state.output_token_ids.extend(
-                            new_token_ids[-num_new_tokens:])
-                # Update the block IDs.
-                if not resumed_from_preemption:
-                    # Append the new blocks to the existing block IDs.
-                    for block_ids, new_ids in zip(  # type: ignore[call-overload]
-                            req_state.block_ids, new_block_ids):
-                        block_ids.extend(new_ids)
-                else:
-                    # The request is resumed from preemption.
-                    # Replace the existing block IDs with the new ones.
-                    req_state.block_ids = new_block_ids
-
-                req_index = self.input_batch.req_id_to_index.get(req_id)
-                if req_index is None:
-                    # The request is not in the persistent batch.
-                    # The request was either preempted and resumed later, or was not
-                    # scheduled in the previous step and needs to be added again.
-                    req_ids_to_add.append(req_id)
-                    continue
-
-                # Update the persistent batch.
-                self.input_batch.num_computed_tokens_cpu[req_index] = (
-                    num_computed_tokens)
-                self.input_batch.block_table.append_row(
-                    new_block_ids, req_index)
-
-                if not is_last_rank:
-                    # Add new_token_ids to token_ids_cpu.
-                    start_token_index = num_computed_tokens
-                    end_token_index = num_computed_tokens + len(new_token_ids)
-                    self.input_batch.token_ids_cpu[
-                        req_index,
-                        start_token_index:end_token_index] = new_token_ids
-                    self.input_batch.num_tokens_no_spec[
-                        req_index] = end_token_index
-                    # Add spec_token_ids to token_ids_cpu.
-                    spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                        req_id, ())
-                    if spec_token_ids:
-                        start_index = end_token_index
-                        end_token_index += len(spec_token_ids)
-                        self.input_batch.token_ids_cpu[
-                            req_index,
-                            start_index:end_token_index] = spec_token_ids
-                    # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
-                    self.input_batch.num_tokens[req_index] = end_token_index
 
         # Check if the batch has changed. If not, we can skip copying the
         # sampling metadata from CPU to GPU.
@@ -835,25 +758,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     # compute completion's mrope_positions on-the-fly
                     dst_start = mrope_pos_ptr
                     dst_end = mrope_pos_ptr + completion_part_len
-
-                    if vllm_version_is("0.9.1"):
-                        self.mrope_positions_cpu[:, dst_start:dst_end] = \
-                            MRotaryEmbedding.get_next_input_positions_tensor(
-                                req.mrope_position_delta,
-                                context_len=num_computed_tokens +
-                                prompt_part_len,
-                                seq_len=num_computed_tokens +
-                                prompt_part_len +
-                                completion_part_len,
-                            )
-                    else:
-                        MRotaryEmbedding.get_next_input_positions_tensor(
-                            out=self.mrope_positions_np,
-                            out_offset=dst_start,
-                            mrope_position_delta=req.mrope_position_delta,
-                            context_len=num_computed_tokens + prompt_part_len,
-                            num_new_tokens=completion_part_len,
-                        )
+                    MRotaryEmbedding.get_next_input_positions_tensor(
+                        out=self.mrope_positions_np,
+                        out_offset=dst_start,
+                        mrope_position_delta=req.mrope_position_delta,
+                        context_len=num_computed_tokens + prompt_part_len,
+                        num_new_tokens=completion_part_len,
+                    )
 
                     mrope_pos_ptr += completion_part_len
@@ -1661,30 +1572,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             for i in discard_sampled_tokens_req_indices:
                 valid_sampled_token_ids[i].clear()
 
-            if not vllm_version_is("0.9.1"):
-                # Cache the sampled tokens in the model runner, so that the scheduler
-                # doesn't need to send them back.
-                # NOTE(woosuk): As an exception, when using PP, the scheduler sends
-                # the sampled tokens back, because there's no direct communication
-                # between the first-stage worker and the last-stage worker.
-                for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
-                    if not sampled_ids:
-                        continue
-                    start_idx = self.input_batch.num_tokens_no_spec[req_idx]
-                    end_idx = start_idx + len(sampled_ids)
-                    assert end_idx <= self.model_config.max_model_len, (
-                        "Sampled token IDs exceed the max model length. "
-                        f"Total number of tokens: {end_idx} > max_model_len: "
-                        f"{self.model_config.max_model_len}")
-                    self.input_batch.token_ids_cpu[
-                        req_idx, start_idx:end_idx] = sampled_ids
-                    self.input_batch.num_tokens_no_spec[req_idx] = end_idx
-                    self.input_batch.num_tokens[req_idx] = end_idx
-                    req_id = self.input_batch.req_ids[req_idx]
-                    req_state = self.requests[req_id]
-                    req_state.output_token_ids.extend(sampled_ids)
+            # Cache the sampled tokens in the model runner, so that the scheduler
+            # doesn't need to send them back.
+            # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+            # the sampled tokens back, because there's no direct communication
+            # between the first-stage worker and the last-stage worker.
+            for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+                if not sampled_ids:
+                    continue
+                start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+                end_idx = start_idx + len(sampled_ids)
+                assert end_idx <= self.model_config.max_model_len, (
+                    "Sampled token IDs exceed the max model length. "
+                    f"Total number of tokens: {end_idx} > max_model_len: "
+                    f"{self.model_config.max_model_len}")
+                self.input_batch.token_ids_cpu[req_idx,
+                                               start_idx:end_idx] = sampled_ids
+                self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+                self.input_batch.num_tokens[req_idx] = end_idx
+                req_id = self.input_batch.req_ids[req_idx]
+                req_state = self.requests[req_id]
+                req_state.output_token_ids.extend(sampled_ids)
 
             spec_token_ids = self._get_spec_token_ids(
                 valid_sampled_token_ids,
@@ -1697,25 +1607,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 attn_metadata,
                 aux_hidden_states,
             )
-        if vllm_version_is("0.9.1"):
-            model_runner_output = ModelRunnerOutput(
-                req_ids=self.input_batch.req_ids,
-                req_id_to_index=self.input_batch.req_id_to_index,
-                sampled_token_ids=valid_sampled_token_ids,
-                spec_token_ids=spec_token_ids,
-                logprobs=logprobs_lists,
-                prompt_logprobs_dict=prompt_logprobs_dict,
-            )
-        else:
-            model_runner_output = ModelRunnerOutput(
-                req_ids=self.input_batch.req_ids,
-                req_id_to_index=self.input_batch.req_id_to_index,
-                sampled_token_ids=valid_sampled_token_ids,
-                spec_token_ids=spec_token_ids,
-                logprobs=logprobs_lists,
-                prompt_logprobs_dict=prompt_logprobs_dict,
-                pooler_output=[],
-            )
+        model_runner_output = ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids,
+            req_id_to_index=self.input_batch.req_id_to_index,
+            sampled_token_ids=valid_sampled_token_ids,
+            spec_token_ids=spec_token_ids,
+            logprobs=logprobs_lists,
+            prompt_logprobs_dict=prompt_logprobs_dict,
+            pooler_output=[],
+        )
 
         durations = ProfileExecuteDuration().pop_captured_sync()
         if durations:
@@ -2024,15 +1925,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                             QKVParallelLinear, RowParallelLinear)):
                 module.weight.data = torch_npu.npu_format_cast(
                     module.weight.data, ACL_FORMAT_FRACTAL_NZ)
-
-        try:
-            # For version compatibility, remove this after we abort vllm v0.9.1 support
-            from vllm.model_executor.models.interfaces import \
-                has_step_pooler  # type: ignore
-            if has_step_pooler(self.model):
-                self.input_batch.logits_processing_needs_token_ids = True
-        except ImportError:
-            pass
+        if has_step_pooler(self.model):
+            self.input_batch.logits_processing_needs_token_ids = True
         if self.drafter:
             logger.info("Loading drafter model...")
             if self.use_aux_hidden_state_outputs:
@@ -2362,14 +2256,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 # Skip requests that require top-p, top-k, etc.
                 req_id = self.input_batch.req_ids[i]
-                if vllm_version_is("0.9.1"):
-                    if not is_spec_decode_supported(req_id, self.input_batch):
-                        draft_token_ids.append([])
-                        continue
-                else:
-                    if req_id in self.input_batch.spec_decode_unsupported_reqs:
-                        draft_token_ids.append([])
-                        continue
+                if req_id in self.input_batch.spec_decode_unsupported_reqs:
+                    draft_token_ids.append([])
+                    continue

                 # Add sampled_token_ids to token_ids_cpu.
                 start_idx = self.input_batch.num_tokens_no_spec[i]
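The hunks above all delete the same compatibility pattern: call sites that branched on `vllm_version_is("0.9.1")` collapse to the v0.9.2 code path. A minimal standalone sketch of that pattern, using a simplified stand-in for the real `vllm_ascend.utils.vllm_version_is` helper and illustrative field names rather than `ModelRunnerOutput`'s actual signature:

```python
# Hedged sketch of the version-gate pattern this PR removes. The real helper
# lives in vllm_ascend.utils; this standalone version is an illustration.
def vllm_version_is(target: str, current: str) -> bool:
    """Return True when the running vLLM version string matches `target`."""
    return current == target


def build_output(version: str) -> dict:
    """Mimic the removed branch: the v0.9.1 output had no pooler_output field."""
    output = {
        "req_ids": ["req-0"],
        "sampled_token_ids": [[42]],
    }
    if not vllm_version_is("0.9.1", current=version):
        output["pooler_output"] = []  # field required from v0.9.2 onwards
    return output
```

With 0.9.1 support dropped, every such branch reduces to its unconditional `else` arm, which is what each `+` block in this diff contains.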

View File

@@ -28,15 +28,13 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
+from vllm.v1.sample.logits_processor import init_builtin_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import MultiGroupBlockTable

 from vllm_ascend.pool.metadata import PoolingMetadata
-from vllm_ascend.utils import vllm_version_is
-
-if not vllm_version_is("0.9.1"):
-    from vllm.v1.spec_decode.utils import is_spec_decode_unsupported

 _SAMPLING_EPS = 1e-5
@@ -253,17 +251,13 @@ class InputBatch:
         self.req_output_token_ids: list[Optional[list[int]]] = []

-        if not vllm_version_is("0.9.1"):
-            from vllm.v1.sample.logits_processor import \
-                init_builtin_logitsprocs
-
-            # Define logits processors.
-            # TODO(andy): logits processor list should be extensible via engine
-            # constructor argument; for now the list is fixed.
-            self.logitsprocs = init_builtin_logitsprocs(
-                pin_memory_available=pin_memory,
-                max_num_reqs=max_num_reqs + 1,
-                device=device)
+        # Define logits processors.
+        # TODO(andy): logits processor list should be extensible via engine
+        # constructor argument; for now the list is fixed.
+        self.logitsprocs = init_builtin_logitsprocs(
+            pin_memory_available=pin_memory,
+            max_num_reqs=max_num_reqs + 1,
+            device=device)

         # This is updated each time the batch constituents change.
         self.sampling_metadata = self._make_sampling_metadata()
@@ -314,8 +308,8 @@ class InputBatch:
         self.block_table.add_row(request.block_ids, req_index)

         if sampling_params := request.sampling_params:
-            if ((not vllm_version_is("0.9.1")) and self.is_spec_decode
-                    and is_spec_decode_unsupported(sampling_params)):
+            if self.is_spec_decode and is_spec_decode_unsupported(
+                    sampling_params):
                 self.spec_decode_unsupported_reqs.add(req_id)
             if sampling_params.sampling_type == SamplingType.GREEDY:
                 # Avoid later division by zero.
@@ -641,48 +635,24 @@ class InputBatch:
                 self.allowed_token_ids_mask, num_reqs)
             allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

-        if vllm_version_is("0.9.1"):
-            return SamplingMetadata(
-                temperature=temperature,
-                all_greedy=self.all_greedy,
-                all_random=self.all_random,
-                top_p=None if self.no_top_p else self.top_p[:num_reqs],
-                top_k=None if self.no_top_k else self.top_k[:num_reqs],
-                min_p=None if self.no_min_p else self.min_p[:num_reqs],
-                generators=self.generators,
-                max_num_logprobs=self.max_num_logprobs,
-                prompt_token_ids=prompt_token_ids,
-                frequency_penalties=self.frequency_penalties[:num_reqs],
-                presence_penalties=self.presence_penalties[:num_reqs],
-                repetition_penalties=self.repetition_penalties[:num_reqs],
-                output_token_ids=cast(list[list[int]],
-                                      self.req_output_token_ids),
-                min_tokens=self.min_tokens,
-                no_penalties=self.no_penalties,
-                logit_bias=self.logit_bias[:num_reqs],
-                allowed_token_ids_mask=allowed_token_ids_mask,
-                bad_words_token_ids=self.bad_words_token_ids,
-            )
-        else:
-            return SamplingMetadata(
-                temperature=temperature,
-                all_greedy=self.all_greedy,
-                all_random=self.all_random,
-                top_p=None if self.no_top_p else self.top_p[:num_reqs],
-                top_k=None if self.no_top_k else self.top_k[:num_reqs],
-                generators=self.generators,
-                max_num_logprobs=self.max_num_logprobs,
-                prompt_token_ids=prompt_token_ids,
-                frequency_penalties=self.frequency_penalties[:num_reqs],
-                presence_penalties=self.presence_penalties[:num_reqs],
-                repetition_penalties=self.repetition_penalties[:num_reqs],
-                output_token_ids=cast(list[list[int]],
-                                      self.req_output_token_ids),
-                no_penalties=self.no_penalties,
-                allowed_token_ids_mask=allowed_token_ids_mask,
-                bad_words_token_ids=self.bad_words_token_ids,
-                logitsprocs=self.logitsprocs,
-            )
+        return SamplingMetadata(
+            temperature=temperature,
+            all_greedy=self.all_greedy,
+            all_random=self.all_random,
+            top_p=None if self.no_top_p else self.top_p[:num_reqs],
+            top_k=None if self.no_top_k else self.top_k[:num_reqs],
+            generators=self.generators,
+            max_num_logprobs=self.max_num_logprobs,
+            prompt_token_ids=prompt_token_ids,
+            frequency_penalties=self.frequency_penalties[:num_reqs],
+            presence_penalties=self.presence_penalties[:num_reqs],
+            repetition_penalties=self.repetition_penalties[:num_reqs],
+            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
+            no_penalties=self.no_penalties,
+            allowed_token_ids_mask=allowed_token_ids_mask,
+            bad_words_token_ids=self.bad_words_token_ids,
+            logitsprocs=self.logitsprocs,
+        )

     @property
     def pooling_metadata(self) -> PoolingMetadata: