[Core] Support pooling (#229)

This PR adds pooling support to vllm-ascend.

Tested with `bge-base-en-v1.5` by encode:
```
from vllm import LLM

# Sample prompts.
prompts = [
  "Hello, my name is",
  "The president of the United States is",
  "The capital of France is",
  "The future of AI is",
]
# Create an LLM.
model = LLM(model="./bge-base-en-v1.5", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:
    print(output.outputs.embedding)  # list of 768 floats (bge-base-en-v1.5)
```

Tested by embedding:
```
from vllm import LLM, SamplingParams

llm = LLM(model="./bge-base-en-v1.5", task="embed")
(output,) = llm.embed("Hello, my name is")

embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
```

Related: https://github.com/vllm-project/vllm-ascend/issues/200

## Known issue
The output accuracy is not yet correct because this feature relies on
`enc-dec` support, which will be added in a follow-up PR by @MengqingCao.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-03-04 15:59:34 +08:00
committed by GitHub
parent 8fda31cafe
commit ae49bfd13a
7 changed files with 258 additions and 71 deletions

View File

@@ -7,7 +7,7 @@
| LoRA | ✗ | Plan in 2025 Q1 |
| Prompt adapter | ✗ | Plan in 2025 Q1 |
| Speculative decoding | ✗ | Plan in 2025 Q1 |
| Pooling | ✗ | Plan in 2025 Q2 |
| Pooling | ✅ | |
| Enc-dec | ✗ | Plan in 2025 Q2 |
| Multi Modality | ✅ (LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Add more model support in 2025 Q1 |
| LogProbs | ✅ ||

View File

@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm_ascend.model_runner import ModelInputForNPUBuilder
from vllm_ascend.worker.model_runner import ModelInputForNPUBuilder
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
@@ -211,6 +211,9 @@ class AscendMetadata(AttentionMetadata):
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]] = None
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
@@ -258,6 +261,9 @@ class AscendMetadata(AttentionMetadata):
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
# Construct & cache prefill-phase attention metadata structure.
self._cached_prefill_metadata = AscendMetadata(
num_prefills=self.num_prefills,
@@ -265,6 +271,7 @@ class AscendMetadata(AttentionMetadata):
num_decode_tokens=0,
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
@@ -297,7 +304,8 @@ class AscendMetadata(AttentionMetadata):
self.seq_lens[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
# Construct & cache decode-phase attention metadata structure.
self._cached_decode_metadata = AscendMetadata(
num_prefills=0,
@@ -305,6 +313,7 @@ class AscendMetadata(AttentionMetadata):
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
block_tables=block_tables,
@@ -322,7 +331,6 @@ class AscendMetadata(AttentionMetadata):
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
_metadata_cls = AscendMetadata
_attn_mask_builder = None # noqa
def __init__(self, input_builder: "ModelInputForNPUBuilder"):
@@ -451,7 +459,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
self.multimodal_placeholder_maps.items()
}
return self._metadata_cls( # type: ignore
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.long,
device=device)
return AscendMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
@@ -459,6 +471,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=self.num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
@@ -528,12 +541,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
shape = [batch_size, seq_len * num_heads * head_size]
"""
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
# View q k v to BSH.
num_tokens = query.shape[0]
query = query.view(-1, self.num_heads, self.head_size)

View File

@@ -105,7 +105,7 @@ class NPUPlatform(Platform):
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm_ascend.worker.NPUWorker"
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
# TODO: Set block_size to 128 will lead unexpected accuracy issue in mla case. Please set block_size to 128 back once the problem is fixed.

View File

View File

@@ -808,6 +808,9 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None
def get_model(self) -> nn.Module:
return self.model
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m:
@@ -1341,6 +1344,3 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
self.execute_model(model_input, kv_caches, intermediate_tensors)
current_platform.synchronize()
return
def get_model(self) -> nn.Module:
return self.model

View File

@@ -0,0 +1,187 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/worker.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm_ascend.worker.model_runner import (ModelInputForNPU,
ModelInputForNPUBuilder,
NPUModelRunnerBase)
@dataclasses.dataclass(frozen=True)
class ModelInputForNPUWithPoolingMetadata(ModelInputForNPU):
    """Model input consumed by the pooling model runner.

    Extends ``ModelInputForNPU`` with one optional field that carries the
    pooling metadata for the batch.
    """

    # Filled in by the runner when the model input is prepared; None until
    # then.
    pooling_metadata: Optional["PoolingMetadata"] = None
class NPUPoolingModelRunner(
        NPUModelRunnerBase[ModelInputForNPUWithPoolingMetadata]):
    """Model runner that runs a forward pass and applies the model's pooler.

    Builds ``ModelInputForNPUWithPoolingMetadata`` inputs and, on the last
    pipeline stage of the driver worker, returns the result of
    ``self.model.pooler`` instead of sampled tokens.
    """

    _model_input_cls: Type[ModelInputForNPUWithPoolingMetadata] = (
        ModelInputForNPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForNPUWithPoolingMetadata:
        """Rebuild a model input from a broadcasted tensor dict.

        Delegates to ``from_broadcasted_tensor_dict`` on the model-input
        class, passing this runner's attention backend.
        """
        return ModelInputForNPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNPUWithPoolingMetadata:
        """Build the model input tensors and attach pooling metadata.

        ``seq_group_metadata_list`` must not be None. ``virtual_engine`` is
        accepted for interface compatibility and is not referenced here.
        """
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        # The model input is a frozen dataclass, so return a modified copy.
        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Collects, per sequence group, its sequence ids paired with its
        pooling params, and merges every group's ``seq_data`` into one dict.
        """
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        seq_data: Dict[int, SequenceData] = {}
        # One pass builds both structures; no index is needed.
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            seq_groups.append((seq_ids, seq_group_metadata.pooling_params))
            seq_data.update(seq_group_metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ):
        """Run one forward pass and pool the resulting hidden states.

        Args:
            model_input: Prepared input; its ``attn_metadata`` must be set.
            kv_caches: Accepted for interface compatibility; not referenced
                in this method.
            intermediate_tensors: Forwarded to the model. When forward-time
                collection is enabled, its ``"model_forward_time"`` entry
                (if any) is added to this stage's measured time.
            num_steps: Must be 1.

        Returns:
            On a pipeline stage that is not the last rank, the model's raw
            output. On the last rank, ``[]`` for non-driver workers and a
            one-element list holding the pooler output for the driver.

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "PoolingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        assert model_input.attn_metadata is not None
        virtual_engine = model_input.virtual_engine
        model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        # Only passed to the model when the runner reports inner state.
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            # Imported lazily: only needed when forward timing is collected.
            import torch_npu
            model_forward_start = torch_npu.npu.Event(enable_timing=True)
            model_forward_end = torch_npu.npu.Event(enable_timing=True)
            model_forward_start.record()

        cross_enc_kwargs = {}
        if model_input.token_types is not None:
            cross_enc_kwargs["token_type_ids"] = model_input.token_types

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
                **cross_enc_kwargs,
                **seqlen_agnostic_kwargs)

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end.record()

        # Only perform pooling in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            if (self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)
                    and self.observability_config is not None
                    and self.observability_config.collect_model_forward_time):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                # Accumulate the time reported by earlier stages, if any.
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_or_intermediate_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

View File

@@ -33,7 +33,6 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
@@ -41,12 +40,13 @@ from vllm.utils import bind_kv_cache
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
from vllm_ascend.model_runner import NPUModelRunner
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import try_register_lib
from vllm_ascend.worker.model_runner import NPUModelRunner
from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner
logger = init_logger(__name__)
@@ -58,15 +58,12 @@ class NPUWorker(LocalOrDistributedWorkerBase):
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
) -> None:
def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False):
# Register ops when worker init.
from vllm_ascend import ops # noqa: F401
@@ -101,7 +98,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
if model_config.runner_type == "pooling":
ModelRunnerClass = PoolingModelRunner
ModelRunnerClass = NPUPoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: ModelRunnerBase = ModelRunnerClass(
@@ -110,8 +107,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
is_driver_worker=is_driver_worker,
**speculative_args,
)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(self.model_runner)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
@@ -155,6 +150,26 @@ class NPUWorker(LocalOrDistributedWorkerBase):
else:
self.profiler = None
def init_device(self) -> None:
if self.device_config.device.type == "npu":
self.device = torch.device(f"npu:{self.local_rank}")
NPUPlatform.set_device(self.device)
NPUPlatform.empty_cache()
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
self._init_worker_distributed_environment(self.parallel_config,
self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
@@ -165,28 +180,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
def init_device(self) -> None:
if self.device_config.device.type == "npu":
# # This env var set by Ray causes exceptions with graph building.
# os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"npu:{self.local_rank}")
current_platform.set_device(self.device)
current_platform.empty_cache()
self.init_npu_memory = current_platform.mem_get_info()[0]
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def save_sharded_state(
self,
path: str,
@@ -206,7 +199,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
self.model_runner.save_tensorized_model(
tensorizer_config=tensorizer_config, )
@current_platform.inference_mode()
@NPUPlatform.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
@@ -219,7 +212,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
current_platform.empty_cache()
NPUPlatform.empty_cache()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
@@ -227,7 +220,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
free_npu_memory, total_npu_memory = current_platform.mem_get_info()
free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_npu_memory - free_npu_memory
@@ -248,7 +241,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
gc.collect()
# TODO: don`t need impl this func after empty_cache in
# Worker.determine_num_available_blocks() unified`
current_platform.empty_cache()
NPUPlatform.empty_cache()
return num_npu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
@@ -448,21 +441,21 @@ class NPUWorker(LocalOrDistributedWorkerBase):
self.model_config,
self.parallel_config)
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
backend: str = "hccl") -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank, backend)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def _init_worker_distributed_environment(
self,
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
backend: str = "hccl") -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank,
backend)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,