diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 716fffc..039c033 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -86,7 +86,7 @@ jobs: - name: Run codespell check run: | CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**') - CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn') + CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn') codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}" - name: Analysing the code with ruff @@ -262,11 +262,13 @@ jobs: pytest -sv tests/e2e/singlecard/test_ilama_lora.py pytest -sv tests/e2e/singlecard/test_guided_decoding.py pytest -sv tests/e2e/singlecard/test_camem.py + pytest -sv tests/e2e/singlecard/test_embedding.py pytest -sv tests/e2e/singlecard/ \ --ignore=tests/e2e/singlecard/test_offline_inference.py \ --ignore=tests/e2e/singlecard/test_ilama_lora.py \ --ignore=tests/e2e/singlecard/test_guided_decoding.py \ - --ignore=tests/e2e/singlecard/test_camem.py + --ignore=tests/e2e/singlecard/test_camem.py \ + --ignore=tests/e2e/singlecard/test_embedding.py - name: Run e2e test on V0 engine if: ${{ github.event_name == 'schedule' }} @@ -281,6 +283,7 @@ jobs: pytest -sv tests/e2e/singlecard/test_guided_decoding.py pytest -sv tests/e2e/singlecard/test_camem.py pytest -sv tests/e2e/singlecard/test_prompt_embedding.py + pytest -sv tests/e2e/singlecard/test_embedding.py pytest -sv tests/e2e/singlecard/ \ --ignore=tests/e2e/singlecard/test_offline_inference.py \ --ignore=tests/e2e/singlecard/test_ilama_lora.py \ @@ -288,7 +291,8 @@ jobs: --ignore=tests/e2e/singlecard/test_camem.py \ --ignore=tests/e2e/singlecard/test_prompt_embedding.py \ --ignore=tests/e2e/singlecard/core/test_ascend_scheduler.py \ - --ignore=tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py + --ignore=tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py \ + --ignore=tests/e2e/singlecard/test_embedding.py e2e-4-cards: needs: [e2e] diff --git a/examples/offline_embed.py b/examples/offline_embed.py new file mode 100644 index 0000000..39b3f92 --- /dev/null +++ b/examples/offline_embed.py @@ -0,0 +1,53 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# Adapted from https://www.modelscope.cn/models/Qwen/Qwen3-Embedding-0.6B +# + +import os + +import torch +from vllm import LLM + +os.environ["VLLM_USE_MODELSCOPE"] = "True" + + +def get_detailed_instruct(task_description: str, query: str) -> str: + return f'Instruct: {task_description}\nQuery:{query}' + + +# Each query must come with a one-sentence instruction that describes the task +task = 'Given a web search query, retrieve relevant passages that answer the query' + +queries = [ + get_detailed_instruct(task, 'What is the capital of China?'), + get_detailed_instruct(task, 'Explain gravity') +] +# No need to add instruction for retrieval documents +documents = [ + "The capital of China is Beijing.", + "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun." +] +input_texts = queries + documents + +model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed") + +outputs = model.embed(input_texts) +embeddings = torch.tensor([o.outputs.embedding for o in outputs]) +# Calculate the similarity scores between the first two queries and the last two documents +scores = (embeddings[:2] @ embeddings[2:].T) +print(scores.tolist()) +# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]] diff --git a/requirements-dev.txt b/requirements-dev.txt index 7fb8786..8bd7dca 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,3 +12,4 @@ xgrammar zmq types-psutil pytest-cov +sentence_transformers diff --git a/tests/conftest.py b/tests/conftest.py index e0d70a1..cc8e1ee 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,18 +19,23 @@ import contextlib import gc -from typing import List, Optional, Tuple, TypeVar, Union +from typing import Any, List, Optional, Tuple, TypeVar, Union import numpy as np import pytest import torch from huggingface_hub import snapshot_download from PIL import Image +from torch import nn +from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, + BatchEncoding, BatchFeature) +from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm import LLM, SamplingParams -from vllm.config import TaskOption +from vllm.config import TaskOption, _get_and_verify_dtype from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams +from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils import is_list_of from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs, @@ -45,6 +50,7 @@ adapt_patch(True) from vllm.distributed.parallel_state import ( # noqa E402 destroy_distributed_environment, destroy_model_parallel) +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict) _M = TypeVar("_M") _PromptMultiModalInput = Union[List[_M], List[List[_M]]] @@ -364,3 +370,131 @@ def prompt_template(request): @pytest.fixture(scope="session") def ilama_lora_files(): return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider") + + +class HfRunner: + + def get_default_device(self): + from vllm.platforms import current_platform + + return ("cpu" + if current_platform.is_cpu() else current_platform.device_type) + + def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: + if x is None or isinstance(x, (bool, )): + return x + + if device is None: + device = self.device + + if isinstance(x, dict): + return {k: self.wrap_device(v, device) for k, v in 
x.items()} + + if hasattr(x, "device") and x.device.type == device: + return x + + return x.to(device) + + def __init__( + self, + model_name: str, + dtype: str = "auto", + *, + model_kwargs: Optional[dict[str, Any]] = None, + trust_remote_code: bool = True, + is_sentence_transformer: bool = False, + is_cross_encoder: bool = False, + skip_tokenizer_init: bool = False, + auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, + ) -> None: + model_name = maybe_model_redirect(model_name) + self.model_name = model_name + + self.config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=trust_remote_code, + ) + self.device = self.get_default_device() + self.dtype = torch_dtype = _get_and_verify_dtype( + self.model_name, + self.config, + dtype=dtype, + is_pooling_model=is_sentence_transformer or is_cross_encoder, + ) + + model_kwargs = model_kwargs if model_kwargs is not None else {} + model_kwargs.setdefault("torch_dtype", torch_dtype) + + if is_sentence_transformer: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + + self.model = SentenceTransformer( + model_name, + device=self.device, + model_kwargs=model_kwargs, + trust_remote_code=trust_remote_code, + ) + elif is_cross_encoder: + # Lazy init required for AMD CI + from sentence_transformers import CrossEncoder + + self.model = CrossEncoder( + model_name, + device=self.device, + automodel_args=model_kwargs, + trust_remote_code=trust_remote_code, + ) + else: + model = auto_cls.from_pretrained( + model_name, + trust_remote_code=trust_remote_code, + **model_kwargs, + ) + + # in case some unquantized custom models are not in same dtype + if (getattr(model, "quantization_method", None) is None + and any(p.dtype != self.dtype + for p in model.parameters())): + model = model.to(dtype=self.dtype) + + if (getattr(model, "quantization_method", None) != "bitsandbytes" + and len({p.device + for p in model.parameters()}) < 2): + model = model.to(device=self.device) + + self.model = model + + if not skip_tokenizer_init: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + ) + + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor # noqa: F401 + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + ) + if skip_tokenizer_init: + self.tokenizer = self.processor.tokenizer + + def encode(self, prompts: list[str], *args, + **kwargs) -> list[list[torch.Tensor]]: + return self.model.encode(prompts, *args, **kwargs) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup_dist_env_and_memory() + + +@pytest.fixture(scope="session") +def hf_runner(): + return HfRunner diff --git a/tests/e2e/singlecard/test_embedding.py b/tests/e2e/singlecard/test_embedding.py new file mode 100644 index 0000000..0ca07a0 --- /dev/null +++ b/tests/e2e/singlecard/test_embedding.py @@ -0,0 +1,72 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
+#
+from collections.abc import Sequence
+from typing import Optional
+
+import pytest
+from modelscope import snapshot_download  # type: ignore[import-untyped]
+
+from tests.conftest import HfRunner
+from tests.utils import check_embeddings_close, matryoshka_fy
+from vllm_ascend.utils import vllm_version_is
+
+
+def run_embedding_correctness_test(
+    hf_model: "HfRunner",
+    inputs: list[str],
+    vllm_outputs: Sequence[list[float]],
+    dimensions: Optional[int] = None,
+):
+    hf_outputs = hf_model.encode(inputs)
+    if dimensions:
+        hf_outputs = matryoshka_fy(hf_outputs, dimensions)
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )
+
+
+# Dummy test to avoid pytest collecting nothing and exiting with code 5
+def test_dummy():
+    assert True
+
+
+@pytest.mark.skipif(vllm_version_is("0.9.1"),
+                    reason="vLLM 0.9.1 does not support embed task for v1")
def test_embed_models_correctness(hf_runner, vllm_runner):
+    queries = ['What is the capital of China?', 'Explain gravity']
+
+    model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B")
+    with vllm_runner(
+            model_name,
+            task="embed",
+            enforce_eager=True,
+    ) as vllm_model:
+        vllm_outputs = vllm_model.encode(queries)
+
+    with hf_runner(
+            model_name,
+            dtype="float32",
+            is_sentence_transformer=True,
+    ) as hf_model:
+        run_embedding_correctness_test(hf_model, queries, vllm_outputs)
diff --git a/tests/ut/worker/test_input_batch.py b/tests/ut/worker/test_input_batch.py
new file mode 100644
index 0000000..cbfd67f
--- /dev/null
+++ b/tests/ut/worker/test_input_batch.py
@@ -0,0 +1,164 @@
+import unittest
+
+import numpy as np
+import torch
+from vllm.sampling_params import SamplingParams
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.block_table import MultiGroupBlockTable
+
+from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
+
+
+def mock_cached_request_state(req_id="1", prompt=None, output=None):
+    # None defaults avoid the shared-mutable-default pitfall: test_add_request
+    # extends output_token_ids in place, which would leak into later calls.
+    prompt = [1, 2, 3] if prompt is None else prompt
+    output = [4, 5, 6] if output is None else output
+    return CachedRequestState(
+        req_id=req_id,
+        prompt_token_ids=prompt,
+        mm_inputs=[],
+        mm_positions=[],
+        sampling_params=SamplingParams(),
+        pooling_params=None,
+        generator=None,
+        block_ids=([], ),
+        num_computed_tokens=0,
+        output_token_ids=output,
+    )
+
+
+class TestInputBatch(unittest.TestCase):
+
+    def setUp(self):
+        self.max_num_reqs = 10
+        self.max_model_len = 32
+        self.max_num_batched_tokens = 132
+        self.vocab_size = 1000
+        self.device = torch.device("cpu")
+        self.block_sizes = [128]
+
+        self.input_batch = InputBatch(
+            max_num_reqs=self.max_num_reqs,
+            max_model_len=self.max_model_len,
+            max_num_batched_tokens=self.max_num_batched_tokens,
+            device=self.device,
+            pin_memory=False,
+            vocab_size=self.vocab_size,
+            block_sizes=self.block_sizes,
+        )
+        self.cached_request_state = mock_cached_request_state()
+
+    def test_shapes_and_defaults(self):
+        # torch tensor shape assertions
+        self.assertEqual(self.input_batch.token_ids_cpu_tensor.shape,
+                         (self.max_num_reqs, self.max_model_len))
+
+        self.assertEqual(self.input_batch.temperature.shape,
+                         (self.max_num_reqs, ))
+        self.assertEqual(self.input_batch.top_k.shape, (self.max_num_reqs, ))
+        self.assertEqual(self.input_batch.min_p_cpu_tensor.shape,
+                         (self.max_num_reqs, ))
+
+        # numpy shape assertions
+        self.assertEqual(self.input_batch.token_ids_cpu.shape,
+                         (self.max_num_reqs, self.max_model_len))
+        self.assertEqual(self.input_batch.num_tokens.shape,
+                         (self.max_num_reqs, ))
+
+        # type assertions
+        self.assertIsInstance(self.input_batch.greedy_reqs, set)
+        self.assertIsInstance(self.input_batch.req_id_to_index, dict)
+        self.assertIsInstance(self.input_batch.sampling_metadata,
+                              SamplingMetadata)
+        self.assertIsInstance(self.input_batch.block_table,
+                              MultiGroupBlockTable)
+        self.assertIsNone(self.input_batch.allowed_token_ids_mask)
+        self.assertIsNone(self.input_batch.allowed_token_ids_mask_cpu_tensor)
+
+    def test_add_request(self):
+        # case 1: add a new req
+        self.input_batch.add_request(self.cached_request_state)
+        self.assertIn(self.cached_request_state.req_id,
+                      self.input_batch.req_id_to_index)
+        req_index = self.input_batch.req_id_to_index[
+            self.cached_request_state.req_id]
+        self.assertEqual(self.input_batch.num_prompt_tokens[req_index],
+                         len(self.cached_request_state.prompt_token_ids))
+        self.assertEqual(self.input_batch.num_tokens[req_index],
+                         self.cached_request_state.num_tokens)
+
+        # case 2: re-add an existing req, which should update its state
+        self.cached_request_state.output_token_ids.extend([7, 8, 9])
+        self.cached_request_state.num_computed_tokens += 3
+        cached_index = self.input_batch.req_id_to_index[
+            self.cached_request_state.req_id]
+        self.input_batch.add_request(self.cached_request_state, cached_index)
+        # Check that this index in the input_batch was updated: the np array
+        # "token_ids_cpu" should be filled with prompt_token_ids + output_token_ids
+        self.assertTrue(
+            np.all(self.input_batch.token_ids_cpu[
+                cached_index, :self.cached_request_state.num_tokens]),
+            msg=f"Token IDs at index {cached_index} did not update correctly.")
+
+        # case 3: adding a req at an index >= max_num_reqs should assert
+        with self.assertRaises(AssertionError):
+            self.input_batch.add_request(self.cached_request_state,
+                                         req_index=self.max_num_reqs)
+
+        # case 4: adding a req whose prompt exceeds max_model_len should raise
+        long_prompt = list(range(self.max_model_len + 1))
+        long_request = mock_cached_request_state(req_id="2",
+                                                 prompt=long_prompt,
+                                                 output=[10])
+        with self.assertRaises(ValueError) as cm:
+            self.input_batch.add_request(long_request)
+        self.assertIn("could not broadcast", str(cm.exception))
+
+    def test_remove_request(self):
+        self.input_batch.add_request(self.cached_request_state)
+        req_index = self.input_batch.remove_request(
+            self.cached_request_state.req_id)
+        self.assertIsNotNone(req_index)
+        self.assertNotIn(self.cached_request_state.req_id,
+                         self.input_batch.req_id_to_index)
+        self.assertIsNone(self.input_batch._req_ids[req_index])
+
+    def test_condense(self):
+        # Let's say we have some requests like below
+        # Index  Req ID
+        #  0      1
+        #  1      2
+        #  2      3
+        #  3      4
+        for i in range(4):
+            request = mock_cached_request_state(req_id=str(i + 1))
+            self.input_batch.add_request(request)
+        removed_req_indices = []
+        id_to_remove = ["2", "4"]  # IDs to remove
+        for req_id in id_to_remove:
+            removed_index = self.input_batch.remove_request(req_id)
+            if removed_index is not None:
+                removed_req_indices.append(removed_index)
+        self.assertEqual(len(removed_req_indices), len(id_to_remove))
+        
self.input_batch.condense(sorted(removed_req_indices, reverse=True)) + + # Check if the remaining requests are condensed correctly + indices = [ + self.input_batch.req_id_to_index[req_id] for req_id in ["1", "3"] + ] + self.assertTrue(all(idx < self.input_batch.num_reqs + for idx in indices)) + + for i in range(self.input_batch.num_reqs): + self.assertIsNotNone(self.input_batch._req_ids[i]) + for i in range(self.input_batch.num_reqs, + len(self.input_batch._req_ids)): + self.assertIsNone(self.input_batch._req_ids[i]) + + for req_id in ["1", "3"]: + idx = self.input_batch.req_id_to_index[req_id] + tokens = self.input_batch.token_ids_cpu[idx] + self.assertTrue( + tokens.any(), + f"Tokens at index {idx} for req {req_id} should not be all zero" + ) diff --git a/tests/utils.py b/tests/utils.py index ced7d9a..2535d08 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,10 +23,13 @@ import signal import subprocess import sys import time +from collections.abc import Sequence from typing import Callable, Optional import openai import requests +import torch +import torch.nn.functional as F from typing_extensions import ParamSpec from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -197,3 +200,37 @@ def fork_new_process_for_each_test( f" args {args} and kwargs {kwargs}") return wrapper + + +def matryoshka_fy(tensor: torch.Tensor, dimensions: int): + tensor = torch.tensor(tensor) + tensor = tensor[..., :dimensions] + tensor = F.normalize(tensor, p=2, dim=1) + return tensor + + +def check_embeddings_close( + *, + embeddings_0_lst: Sequence[list[float]], + embeddings_1_lst: Sequence[list[float]], + name_0: str, + name_1: str, + tol: float = 1e-3, +) -> None: + assert len(embeddings_0_lst) == len(embeddings_1_lst) + + for prompt_idx, (embeddings_0, embeddings_1) in enumerate( + zip(embeddings_0_lst, embeddings_1_lst)): + assert len(embeddings_0) == len(embeddings_1), ( + f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}") + + sim = F.cosine_similarity(torch.tensor(embeddings_0), + torch.tensor(embeddings_1), + dim=0) + + fail_msg = (f"Test{prompt_idx}:" + f"\nCosine similarity: \t{sim:.4f}" + f"\n{name_0}:\t{embeddings_0[:16]!r}" + f"\n{name_1}:\t{embeddings_1[:16]!r}") + + assert sim >= 1 - tol, fail_msg diff --git a/vllm_ascend/pool/metadata.py b/vllm_ascend/pool/metadata.py new file mode 100644 index 0000000..6dca038 --- /dev/null +++ b/vllm_ascend/pool/metadata.py @@ -0,0 +1,32 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# Adapted from vllm-project/vllm/vllm/v1/pool/metadata.py +# +from dataclasses import dataclass +from typing import Optional + +import torch +from vllm.pooling_params import PoolingParams + + +@dataclass +class PoolingMetadata: + """Tensors for pooling.""" + + prompt_lens: torch.Tensor + prompt_token_ids: Optional[torch.Tensor] + pooling_params: list[PoolingParams] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c75021a..3f99371 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -47,6 +47,7 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, @@ -62,7 +63,6 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache -from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, @@ -74,12 +74,14 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState, AscendMetadata) from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata from vllm_ascend.platform import NPUPlatform +from vllm_ascend.pool.metadata import PoolingMetadata from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, ProfileExecuteDuration, is_310p, vllm_version_is) from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer +from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] @@ -177,6 +179,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.cache_config.cache_dtype] self.is_multimodal_model = self.model_config.is_multimodal_model + self.is_pooling_model = self.model_config.pooler_config is not None if self.is_multimodal_model: self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.model_config.get_hidden_size()), @@ -389,38 +392,29 @@ class NPUModelRunner(LoRAModelRunnerMixin): for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id sampling_params = new_req_data.sampling_params - if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + if sampling_params and \ + sampling_params.sampling_type == SamplingType.RANDOM_SEED: generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: generator = None - if vllm_version_is("0.9.1"): - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - mm_inputs=new_req_data.mm_inputs, - mm_positions=new_req_data.mm_positions, - sampling_params=sampling_params, - generator=generator, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - lora_request=new_req_data.lora_request, - ) - else: - 
self.requests[req_id] = CachedRequestState(
-                    req_id=req_id,
-                    prompt_token_ids=new_req_data.prompt_token_ids,
-                    mm_inputs=new_req_data.mm_inputs,
-                    mm_positions=new_req_data.mm_positions,
-                    sampling_params=sampling_params,
-                    pooling_params=None,
-                    generator=generator,
-                    block_ids=new_req_data.block_ids,
-                    num_computed_tokens=new_req_data.num_computed_tokens,
-                    output_token_ids=[],
-                    lora_request=new_req_data.lora_request,
-                )
+
+        # For vllm v0.9.1 version compatibility, we check if
+        # `pooling_params` is present in the new request data.
+        pooling_params = getattr(new_req_data, "pooling_params", None)
+        self.requests[req_id] = CachedRequestState(
+            req_id=req_id,
+            prompt_token_ids=new_req_data.prompt_token_ids,
+            mm_inputs=new_req_data.mm_inputs,
+            mm_positions=new_req_data.mm_positions,
+            sampling_params=sampling_params,
+            pooling_params=pooling_params,
+            generator=generator,
+            block_ids=new_req_data.block_ids,
+            num_computed_tokens=new_req_data.num_computed_tokens,
+            output_token_ids=[],
+            lora_request=new_req_data.lora_request,
+        )
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -893,7 +887,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata,
-               torch.Tensor, int, torch.Tensor, torch.Tensor]:
+               torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
         # Check input valid
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -1173,7 +1167,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             hidden_states, aux_hidden_states = hidden_states
 
         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
-                total_num_scheduled_tokens, sample_indices, aux_hidden_states)
+                total_num_scheduled_tokens, sample_indices, aux_hidden_states,
+                num_scheduled_tokens)
 
     def _get_cumsum_and_arange(
         self,
@@ -1431,6 +1426,47 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                       hidden_states, attn_metadata)
         return spec_token_ids
 
+    def _pool(
+        self,
+        hidden_states: torch.Tensor,
+        num_scheduled_tokens: int,
+        num_scheduled_tokens_np: np.ndarray,
+    ) -> ModelRunnerOutput:
+        assert self.input_batch.num_reqs ==\
+            len(self.input_batch.pooling_params), \
+            "Either all or none of the requests in" \
+            " a batch must be pooling requests"
+
+        extracted_hidden_states = list(
+            torch.split(hidden_states[:num_scheduled_tokens],
+                        num_scheduled_tokens_np.tolist()))
+
+        pooling_metadata = self.input_batch.pooling_metadata
+
+        raw_pooler_output = self.model.pooler(
+            hidden_states=extracted_hidden_states,
+            pooling_metadata=pooling_metadata)
+
+        pooler_output: list[Optional[torch.Tensor]] = []
+        seq_lens = self.seq_lens[:self.input_batch.num_reqs]
+        for raw_output, seq_len, prompt_len in zip(
+                raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
+
+            if seq_len == prompt_len:
+                pooler_output.append(raw_output.data.cpu())
+            else:
+                pooler_output.append(None)
+
+        return ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids,
+            req_id_to_index=self.input_batch.req_id_to_index,
+            sampled_token_ids=[],
+            spec_token_ids=None,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=pooler_output,
+        )
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -1444,12 +1480,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 # Return empty ModelRunnerOuptut if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
 
         (attn_metadata, hidden_states, spec_decode_metadata, positions,
-         num_scheduled_tokens, sample_indices,
-         aux_hidden_states) = (self._process_reqs(scheduler_output,
-                                                  intermediate_tensors))
+         num_scheduled_tokens, sample_indices, aux_hidden_states,
+         num_scheduled_tokens_np) = (self._process_reqs(
+             scheduler_output, intermediate_tensors))
 
         with ProfileExecuteDuration().capture_async("post process"):
+            if self.input_batch.pooling_params:
+                return self._pool(hidden_states, num_scheduled_tokens,
+                                  num_scheduled_tokens_np)
             logits = self.model.compute_logits(hidden_states[sample_indices],
                                                None)
             if self.use_eagle:
@@ -1795,21 +1834,75 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             hidden_states = self._dummy_run(self.max_num_tokens)
 
         if get_pp_group().is_last_rank:
-            hidden_states = hidden_states[logit_indices]
-            logits = self.model.compute_logits(hidden_states, None)
+            if self.is_pooling_model:
+                output = self._dummy_pooler_run(hidden_states)
+            else:
+                # TODO: need to run a dummy sampler for the generate task
+                hidden_states = hidden_states[logit_indices]
+                output = self.model.compute_logits(hidden_states, None)
         else:
-            logits = None
+            output = None
         NPUPlatform.synchronize()
-        del hidden_states, logits
+        del hidden_states, output
         self.encoder_cache.clear()
         gc.collect()
 
+    @torch.inference_mode()
+    def _dummy_pooler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+
+        num_tokens = hidden_states.shape[0]
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        num_reqs = min(num_tokens, max_num_reqs)
+        min_tokens_per_req = num_tokens // num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+
+        hidden_states_list = list(
+            torch.split(hidden_states, num_scheduled_tokens_list))
+
+        req_num_tokens = num_tokens // num_reqs
+
+        dummy_metadata = PoolingMetadata(
+            prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
+                                     device=self.device),
+            prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
+                                         dtype=torch.int32,
+                                         device=self.device),
+            pooling_params=[PoolingParams()] * num_reqs)
+
+        try:
+            pooler_output = self.model.pooler(hidden_states=hidden_states_list,
+                                              pooling_metadata=dummy_metadata)
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                raise RuntimeError(
+                    "NPU out of memory occurred when warming up pooler with "
+                    f"{num_reqs} dummy requests. Please try lowering "
+                    "`max_num_seqs` or `gpu_memory_utilization` when "
+                    "initializing the engine.") from e
+            else:
+                raise e
+        return pooler_output
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
+            try:
+                # For version compatibility; remove this once vllm v0.9.1 support is dropped
+                from vllm.model_executor.models.interfaces import \
+                    has_step_pooler  # type: ignore
+                if has_step_pooler(self.model):
+                    self.input_batch.logits_processing_needs_token_ids = True
+            except ImportError:
+                pass
             if self.drafter:
                 logger.info("Loading drafter model...")
                 if self.use_aux_hidden_state_outputs:
diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py
new file mode 100644
index 0000000..1364505
--- /dev/null
+++ b/vllm_ascend/worker/npu_input_batch.py
@@ -0,0 +1,681 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. 
All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py +# + +from dataclasses import dataclass +from typing import Optional, cast + +import numpy as np +import torch +from vllm.lora.request import LoRARequest +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.v1.outputs import LogprobsTensors +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.utils import copy_slice +from vllm.v1.worker.block_table import MultiGroupBlockTable + +from vllm_ascend.pool.metadata import PoolingMetadata + +_SAMPLING_EPS = 1e-5 + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: list[int] + mm_inputs: list[MultiModalKwargs] + mm_positions: list[PlaceholderRange] + sampling_params: Optional[SamplingParams] + pooling_params: Optional[PoolingParams] + generator: Optional[torch.Generator] + + block_ids: tuple[list[int], ...] + num_computed_tokens: int + output_token_ids: list[int] + + mrope_positions: Optional[torch.Tensor] = None + mrope_position_delta: Optional[int] = None + + lora_request: Optional[LoRARequest] = None + + def __post_init__(self): + self.num_prompt_tokens = len(self.prompt_token_ids) + + @property + def num_tokens(self) -> int: + return self.num_prompt_tokens + len(self.output_token_ids) + + def get_token_id(self, idx: int) -> int: + if idx < self.num_prompt_tokens: + return self.prompt_token_ids[idx] + else: + return self.output_token_ids[idx - self.num_prompt_tokens] + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + block_sizes: list[int], # The block_size of each kv cache group + logits_processing_needs_token_ids: bool = False, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_batched_tokens = max_num_batched_tokens + self.device = device + self.pin_memory = pin_memory + self.vocab_size = vocab_size + self.logits_processing_needs_token_ids = ( + logits_processing_needs_token_ids) + + self._req_ids: list[Optional[str]] = [] + self.req_id_to_index: dict[str, int] = {} + + # TODO(woosuk): This buffer could be too large if max_model_len is big. + # Find a way to reduce the CPU memory usage. + # This buffer is not directly transferred to the NPU, so it does not + # need to be pinned. 
+ self.token_ids_cpu_tensor = torch.zeros( + (max_num_reqs, max_model_len), + device="cpu", + dtype=torch.int32, + pin_memory=False, + ) + self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) + self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_computed_tokens_cpu_tensor = torch.zeros( + (max_num_reqs, ), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.num_computed_tokens_cpu = \ + self.num_computed_tokens_cpu_tensor.numpy() + + # Block table. + self.block_table = MultiGroupBlockTable( + max_num_reqs=max_num_reqs, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + pin_memory=pin_memory, + device=device, + block_sizes=block_sizes, + ) + + # Sampling-related. + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: set[str] = set() + self.random_reqs: set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: set[str] = set() + + self.min_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.min_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.min_p_cpu = self.min_p_cpu_tensor.numpy() + self.min_p_reqs: set[str] = set() + + # Frequency penalty related data structures + self.frequency_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.frequency_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.frequency_penalties_cpu = \ + self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_reqs: set[str] = set() + + # Presence penalty related data structures + self.presence_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( + ) + self.presence_penalties_reqs: set[str] = set() + + # Repetition penalty related data structures + self.repetition_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.repetition_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.repetition_penalties_cpu = \ + self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_reqs: set[str] = set() + + # req_index -> (min_tokens, stop_token_ids) + self.min_tokens: dict[int, tuple[int, set[int]]] = {} + + # lora related + self.request_lora_mapping = np.zeros((self.max_num_reqs, ), + dtype=np.int32) + 
self.lora_id_to_request_ids: dict[int, set[str]] = {} + self.lora_id_to_lora_request: dict[int, LoRARequest] = {} + + # req_index -> generator + # NOTE(woosuk): The indices of the requests that do not have their own + # generator should not be included in the dictionary. + self.generators: dict[int, torch.Generator] = {} + + self.num_logprobs: dict[str, int] = {} + # NOTE(rob): num_prompt_logprobs only includes reqs + # that are currently in the prefill phase. + self.num_prompt_logprobs: dict[str, int] = {} + + # To accumulate prompt logprobs tensor chunks across prefill steps. + self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} + + self.logit_bias: list[Optional[dict[int, + float]]] = [None] * max_num_reqs + self.has_allowed_token_ids: set[str] = set() + # NOTE(lufang): In the mask tensor, if the corresponding token allowed, + # the value is False. Since we use masked_fill_ to set -inf. + self.allowed_token_ids_mask: Optional[torch.Tensor] = None + self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None + + # req_index -> bad_words_token_ids + self.bad_words_token_ids: dict[int, list[list[int]]] = {} + + self.req_output_token_ids: list[Optional[list[int]]] = [] + + # This is updated each time the batch constituents change. + self.sampling_metadata = self._make_sampling_metadata() + + self.pooling_params: dict[str, PoolingParams] = {} + + @property + def req_ids(self) -> list[str]: + # None elements should only be present transiently + # while performing state updates to the batch. + return cast(list[str], self._req_ids) + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + if req_index == len(self._req_ids): + self._req_ids.append(req_id) + self.req_output_token_ids.append(request.output_token_ids) + else: + self._req_ids[req_index] = req_id + self.req_output_token_ids[req_index] = request.output_token_ids + + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.num_prompt_tokens[req_index] = num_prompt_tokens + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + # Number of token ids in token_ids_cpu. + # NOTE(woosuk): This may include spec decode tokens. + self.num_tokens[req_index] = request.num_tokens + # Number of tokens without spec decode tokens. + self.num_tokens_no_spec[req_index] = request.num_tokens + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + self.block_table.add_row(request.block_ids, req_index) + + if sampling_params := request.sampling_params: + if sampling_params.sampling_type == SamplingType.GREEDY: + # Avoid later division by zero. 
+ self.temperature_cpu[req_index] = -1.0 + self.greedy_reqs.add(req_id) + else: + self.temperature_cpu[req_index] = sampling_params.temperature + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + top_k = sampling_params.top_k + if 0 < top_k < self.vocab_size: + self.top_k_reqs.add(req_id) + else: + top_k = self.vocab_size + self.top_k_cpu[req_index] = top_k + self.min_p_cpu[req_index] = sampling_params.min_p + self.frequency_penalties_cpu[ + req_index] = sampling_params.frequency_penalty + if sampling_params.min_p > _SAMPLING_EPS: + self.min_p_reqs.add(req_id) + if sampling_params.frequency_penalty != 0.0: + self.frequency_penalties_reqs.add(req_id) + self.presence_penalties_cpu[ + req_index] = sampling_params.presence_penalty + if sampling_params.presence_penalty != 0.0: + self.presence_penalties_reqs.add(req_id) + self.repetition_penalties_cpu[ + req_index] = sampling_params.repetition_penalty + if sampling_params.repetition_penalty != 1.0: + self.repetition_penalties_reqs.add(req_id) + if sampling_params.min_tokens: + self.min_tokens[req_index] = ( + sampling_params.min_tokens, + sampling_params.all_stop_token_ids) + + # NOTE(woosuk): self.generators should not include the requests that + # do not have their own generator. + if request.generator is not None: + self.generators[req_index] = request.generator + + if sampling_params.logprobs is not None: + self.num_logprobs[req_id] = sampling_params.logprobs + if sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[ + req_id] = sampling_params.prompt_logprobs + if sampling_params.logit_bias is not None: + self.logit_bias[req_index] = sampling_params.logit_bias + + if sampling_params.allowed_token_ids: + self.has_allowed_token_ids.add(req_id) + if self.allowed_token_ids_mask_cpu_tensor is None: + # Lazy allocation for this tensor, which can be large. + # False means we don't fill with -inf. + self.allowed_token_ids_mask = torch.zeros( + self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device=self.device) + self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device="cpu") + self.allowed_token_ids_mask_cpu_tensor[req_index] = True + # False means we don't fill with -inf. 
+ self.allowed_token_ids_mask_cpu_tensor[req_index][ + sampling_params.allowed_token_ids] = False + + if sampling_params.bad_words_token_ids: + self.bad_words_token_ids[ + req_index] = sampling_params.bad_words_token_ids + else: + assert request.pooling_params is not None + self.pooling_params[req_id] = request.pooling_params + + # Add request lora ID + if request.lora_request: + lora_id = request.lora_request.lora_int_id + if lora_id not in self.lora_id_to_request_ids: + self.lora_id_to_request_ids[lora_id] = set() + + self.request_lora_mapping[req_index] = lora_id + self.lora_id_to_request_ids[lora_id].add(request.req_id) + self.lora_id_to_lora_request[lora_id] = request.lora_request + else: + # No LoRA + self.request_lora_mapping[req_index] = 0 + + def remove_request(self, req_id: str) -> Optional[int]: + """This method must always be followed by a call to condense().""" + + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self._req_ids[req_index] = None + self.req_output_token_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.min_p_reqs.discard(req_id) + self.min_tokens.pop(req_index, None) + self.frequency_penalties_reqs.discard(req_id) + self.presence_penalties_reqs.discard(req_id) + self.repetition_penalties_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.num_prompt_logprobs.pop(req_id, None) + self.in_progress_prompt_logprobs_cpu.pop(req_id, None) + + # LoRA + lora_id = self.request_lora_mapping[req_index] + if lora_id != 0: + self.lora_id_to_request_ids[lora_id].discard(req_id) + if len(self.lora_id_to_request_ids[lora_id]) == 0: + self.lora_id_to_request_ids.pop(lora_id) + self.lora_id_to_lora_request.pop(lora_id) + self.request_lora_mapping[req_index] = 0 + + self.logit_bias[req_index] = None + self.has_allowed_token_ids.discard(req_id) + if self.allowed_token_ids_mask_cpu_tensor is not None: + # False means we don't fill with -inf. + self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) + self.bad_words_token_ids.pop(req_index, None) + self.pooling_params.pop(req_id, None) + return req_index + + def condense(self, empty_req_indices: list[int]) -> None: + """Move non-empty requests down into lower, empty indices. + + Args: + empty_req_indices: empty batch indices, sorted descending. + """ + num_reqs = self.num_reqs + if num_reqs == 0: + # The batched states are empty. + self._req_ids.clear() + self.req_output_token_ids.clear() + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. 
+ req_id = self._req_ids[last_req_index] + output_token_ids = self.req_output_token_ids[last_req_index] + assert req_id is not None + self._req_ids[empty_index] = req_id + self._req_ids[last_req_index] = None + self.req_output_token_ids[empty_index] = output_token_ids + self.req_output_token_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + num_tokens = self.num_tokens[last_req_index] + self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ + last_req_index, :num_tokens] + self.num_tokens[empty_index] = num_tokens + self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ + last_req_index] + self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table.move_row(last_req_index, empty_index) + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + self.frequency_penalties_cpu[ + empty_index] = self.frequency_penalties_cpu[last_req_index] + self.presence_penalties_cpu[ + empty_index] = self.presence_penalties_cpu[last_req_index] + self.repetition_penalties_cpu[ + empty_index] = self.repetition_penalties_cpu[last_req_index] + self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + min_token = self.min_tokens.pop(last_req_index, None) + if min_token is not None: + self.min_tokens[empty_index] = min_token + + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + + self.logit_bias[empty_index] = self.logit_bias[last_req_index] + + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[ + empty_index] = self.allowed_token_ids_mask_cpu_tensor[ + last_req_index] + + bad_words_token_ids = self.bad_words_token_ids.pop( + last_req_index, None) + if bad_words_token_ids is not None: + self.bad_words_token_ids[empty_index] = bad_words_token_ids + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + # Trim lists to the batch size. + del self._req_ids[self.num_reqs:] + del self.req_output_token_ids[self.num_reqs:] + + def refresh_sampling_metadata(self): + self.sampling_metadata = self._make_sampling_metadata() + + def _make_sampling_metadata(self) -> SamplingMetadata: + num_reqs = self.num_reqs + if not self.all_greedy: + temperature = copy_slice(self.temperature_cpu_tensor, + self.temperature, num_reqs) + else: + temperature = None + if not self.no_top_p: + copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) + if not self.no_top_k: + copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) + if not self.no_min_p: + copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs) + + if not self.no_penalties: + # Since syncing these tensors is expensive only copy them + # if necessary i.e. if there are requests which require + # penalties to be applied during sampling. 
+ copy_slice(self.frequency_penalties_cpu_tensor, + self.frequency_penalties, num_reqs) + copy_slice(self.presence_penalties_cpu_tensor, + self.presence_penalties, num_reqs) + copy_slice(self.repetition_penalties_cpu_tensor, + self.repetition_penalties, num_reqs) + + needs_prompt_token_ids = (not self.no_penalties or + (self.num_reqs > 0 + and self.logits_processing_needs_token_ids)) + if needs_prompt_token_ids: + # The prompt tokens are used only for applying penalties or + # step pooling during the sampling/pooling process. + # Hence copy these tensors only when there are requests which + # need penalties/step_pooler to be applied. + prompt_token_ids = self._make_prompt_token_ids_tensor() + else: + prompt_token_ids = None + + allowed_token_ids_mask: Optional[torch.Tensor] = None + if not self.no_allowed_token_ids: + assert self.allowed_token_ids_mask is not None + copy_slice(self.allowed_token_ids_mask_cpu_tensor, + self.allowed_token_ids_mask, num_reqs) + allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] + + return SamplingMetadata( + temperature=temperature, + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=None if self.no_top_p else self.top_p[:num_reqs], + top_k=None if self.no_top_k else self.top_k[:num_reqs], + min_p=None if self.no_min_p else self.min_p[:num_reqs], + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + prompt_token_ids=prompt_token_ids, + frequency_penalties=self.frequency_penalties[:num_reqs], + presence_penalties=self.presence_penalties[:num_reqs], + repetition_penalties=self.repetition_penalties[:num_reqs], + output_token_ids=cast(list[list[int]], self.req_output_token_ids), + min_tokens=self.min_tokens, + no_penalties=self.no_penalties, + logit_bias=self.logit_bias[:num_reqs], + allowed_token_ids_mask=allowed_token_ids_mask, + bad_words_token_ids=self.bad_words_token_ids, + ) + + @property + def pooling_metadata(self) -> PoolingMetadata: + if len(self.pooling_params) == 0: + pooling_params = [] + else: + # Note, for now this assumes that all request in the batch + # are either sampling or pooling requests + assert len(self.req_ids) == len(self.pooling_params) + pooling_params = [ + self.pooling_params[req_id] for req_id in self.req_ids + ] + + return PoolingMetadata( + prompt_lens=torch.from_numpy( + self.num_prompt_tokens[:self.num_reqs]).to(self.device), + prompt_token_ids=self.sampling_metadata.prompt_token_ids, + pooling_params=pooling_params, + ) + + def _make_prompt_token_ids_tensor(self) -> torch.Tensor: + max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + prompt_token_ids_cpu_tensor = torch.empty( + (self.num_reqs, max_prompt_len), + device="cpu", + dtype=torch.int64, + pin_memory=self.pin_memory, + ) + prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() + prompt_token_ids[:] = self.token_ids_cpu[:self. + num_reqs, :max_prompt_len] + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. + for i in range(self.num_reqs): + prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, + non_blocking=True) + + def make_lora_inputs( + self, num_scheduled_tokens: np.ndarray + ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: + """ + Given the num_scheduled_tokens for each request in the batch, return + datastructures used to activate the current LoRAs. + Returns: + 1. 
prompt_lora_mapping: A tuple of size self.num_reqs where
+               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
+            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
+               where token_lora_mapping[i] is the LoRA id to use for the ith token.
+            3. lora_requests: Set of relevant LoRA requests.
+        """
+
+        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
+        prompt_lora_mapping = tuple(req_lora_mapping)
+        token_lora_mapping = tuple(
+            req_lora_mapping.repeat(num_scheduled_tokens))
+        active_lora_requests: set[LoRARequest] = set(
+            self.lora_id_to_lora_request.values())
+
+        return prompt_lora_mapping, token_lora_mapping, active_lora_requests
+
+    @property
+    def num_reqs(self) -> int:
+        return len(self.req_id_to_index)
+
+    @property
+    def all_greedy(self) -> bool:
+        return len(self.random_reqs) == 0
+
+    @property
+    def all_random(self) -> bool:
+        return len(self.greedy_reqs) == 0
+
+    @property
+    def no_top_p(self) -> bool:
+        return len(self.top_p_reqs) == 0
+
+    @property
+    def no_top_k(self) -> bool:
+        return len(self.top_k_reqs) == 0
+
+    @property
+    def no_min_p(self) -> bool:
+        return len(self.min_p_reqs) == 0
+
+    @property
+    def no_penalties(self) -> bool:
+        return (len(self.presence_penalties_reqs) == 0
+                and len(self.frequency_penalties_reqs) == 0
+                and len(self.repetition_penalties_reqs) == 0)
+
+    @property
+    def max_num_logprobs(self) -> Optional[int]:
+        return max(self.num_logprobs.values()) if self.num_logprobs else None
+
+    @property
+    def no_prompt_logprob(self) -> bool:
+        return not self.num_prompt_logprobs
+
+    @property
+    def no_allowed_token_ids(self) -> bool:
+        return len(self.has_allowed_token_ids) == 0
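
Note on the pooling path added above: `_pool` (and `_dummy_pooler_run`) hinge on one step: the runner's flat hidden-states buffer holds every scheduled token of the batch back to back, so it is split into per-request segments by the per-request token counts before being handed to the model's pooler. A minimal runnable sketch of that step; the toy sizes and the stand-in mean pooler are assumptions for illustration (the real pooler comes from the model's pooler config):

import torch

hidden_size = 8
num_scheduled_tokens = [3, 5, 2]  # tokens scheduled per request this step
flat = torch.randn(sum(num_scheduled_tokens), hidden_size)  # flat buffer

# The same split _pool performs with num_scheduled_tokens_np.tolist()
per_request = torch.split(flat, num_scheduled_tokens)

# Stand-in pooler: mean over each request's tokens
pooled = [seg.mean(dim=0) for seg in per_request]
assert all(p.shape == (hidden_size, ) for p in pooled)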
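
Note on `InputBatch.condense`: it compacts the batch with a two-pointer sweep, filling each vacated slot (lowest index first) with the live request at the highest index, so live requests stay contiguous at the front and every per-request state array keeps the same row alignment. A standalone sketch of the control flow, assuming a plain Python list stands in for those state arrays:

def condense(slots, empty_indices):
    # slots: per-request state; None marks a removed request.
    # empty_indices: vacated indices, sorted in descending order.
    last = len(slots) - 1
    while empty_indices:
        # Skip trailing slots that are themselves vacated.
        while last in empty_indices:
            last -= 1
        # Fill the smallest vacated index from the back of the batch.
        empty = empty_indices.pop()
        if empty >= last:
            break
        slots[empty], slots[last] = slots[last], None
        last -= 1
    return slots

print(condense(["1", None, "3", None], [3, 1]))  # ['1', '3', None, None]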