diff --git a/docs/source/user_guide/suppoted_features.md b/docs/source/user_guide/suppoted_features.md
index 39055b6..b864ef6 100644
--- a/docs/source/user_guide/suppoted_features.md
+++ b/docs/source/user_guide/suppoted_features.md
@@ -7,7 +7,7 @@
 | LoRA | ✗ | Plan in 2025 Q1 |
 | Prompt adapter | ✗ | Plan in 2025 Q1 |
 | Speculative decoding | ✗ | Plan in 2025 Q1 |
-| Pooling | ✗ | Plan in 2025 Q2 |
+| Pooling | ✅ | |
 | Enc-dec | ✗ | Plan in 2025 Q2 |
 | Multi Modality | ✅ (LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Add more model support in 2025 Q1 |
 | LogProbs | ✅ ||
diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py
index 674e100..3b1eb2b 100644
--- a/vllm_ascend/attention.py
+++ b/vllm_ascend/attention.py
@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
-    from vllm_ascend.model_runner import ModelInputForNPUBuilder
+    from vllm_ascend.worker.model_runner import ModelInputForNPUBuilder
 
 
 def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
@@ -211,6 +211,9 @@ class AscendMetadata(AttentionMetadata):
     # the computed tokens + new tokens None if it is a decoding.
     seq_lens: Optional[List[int]] = None
 
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor] = None
+
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
 
@@ -258,6 +261,9 @@ class AscendMetadata(AttentionMetadata):
         block_tables = (None if self.block_tables is None else
                         self.block_tables[:self.num_prefills])
 
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[:self.num_prefills])
+
         # Construct & cache prefill-phase attention metadata structure.
         self._cached_prefill_metadata = AscendMetadata(
             num_prefills=self.num_prefills,
@@ -265,6 +271,7 @@ class AscendMetadata(AttentionMetadata):
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
             max_decode_seq_len=0,
@@ -297,7 +304,8 @@ class AscendMetadata(AttentionMetadata):
                     self.seq_lens[self.num_prefills:])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[self.num_prefills:])
-
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[self.num_prefills:])
         # Construct & cache decode-phase attention metadata structure.
         self._cached_decode_metadata = AscendMetadata(
             num_prefills=0,
@@ -305,6 +313,7 @@ class AscendMetadata(AttentionMetadata):
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
             block_tables=block_tables,
@@ -322,7 +331,6 @@ class AscendMetadata(AttentionMetadata):
 
 
 class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
-    _metadata_cls = AscendMetadata
     _attn_mask_builder = None  # noqa
 
     def __init__(self, input_builder: "ModelInputForNPUBuilder"):
@@ -451,7 +459,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
             self.multimodal_placeholder_maps.items()
         }
 
-        return self._metadata_cls(  # type: ignore
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.long,
+                                       device=device)
+
+        return AscendMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
             multi_modal_placeholder_index_maps=placeholder_index_maps,
@@ -459,6 +471,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=self.num_decode_tokens,
             seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
@@ -528,12 +541,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
             shape = [batch_size, seq_len * num_heads * head_size]
         """
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-        attn_type = self.attn_type
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "PallasAttentionBackendImpl")
         # View q k v to BSH.
         num_tokens = query.shape[0]
         query = query.view(-1, self.num_heads, self.head_size)
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 45647b0..9e84a13 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -105,7 +105,7 @@ class NPUPlatform(Platform):
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
-            parallel_config.worker_cls = "vllm_ascend.worker.NPUWorker"
+            parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             # TODO: Set block_size to 128 will lead unexpected accuracy issue in mla case. Please set block_size to 128 back once the problem is fixed.
diff --git a/vllm_ascend/worker/__init__.py b/vllm_ascend/worker/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vllm_ascend/model_runner.py b/vllm_ascend/worker/model_runner.py
similarity index 100%
rename from vllm_ascend/model_runner.py
rename to vllm_ascend/worker/model_runner.py
index 1cf06dd..d24ff7e 100644
--- a/vllm_ascend/model_runner.py
+++ b/vllm_ascend/worker/model_runner.py
@@ -808,6 +808,9 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
             SamplingMetadataCache() \
             if self.parallel_config.pipeline_parallel_size == 1 else None
 
+    def get_model(self) -> nn.Module:
+        return self.model
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
@@ -1341,6 +1344,3 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
             self.execute_model(model_input, kv_caches, intermediate_tensors)
             current_platform.synchronize()
         return
-
-    def get_model(self) -> nn.Module:
-        return self.model
diff --git a/vllm_ascend/worker/pooling_model_runner.py b/vllm_ascend/worker/pooling_model_runner.py
new file mode 100644
index 0000000..6fbe2d1
--- /dev/null
+++ b/vllm_ascend/worker/pooling_model_runner.py
@@ -0,0 +1,187 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm-project/vllm/vllm/worker/worker.py
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import dataclasses
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+from vllm.distributed import get_pp_group
+from vllm.forward_context import set_forward_context
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.multimodal import MultiModalKwargs
+from vllm.pooling_params import PoolingParams
+from vllm.sequence import (IntermediateTensors, SequenceData,
+                           SequenceGroupMetadata)
+
+from vllm_ascend.worker.model_runner import (ModelInputForNPU,
+                                             ModelInputForNPUBuilder,
+                                             NPUModelRunnerBase)
+
+
+@dataclasses.dataclass(frozen=True)
+class ModelInputForNPUWithPoolingMetadata(ModelInputForNPU):
+    """
+    Used by the PoolingModelRunner.
+ """ + pooling_metadata: Optional["PoolingMetadata"] = None + + +class NPUPoolingModelRunner( + NPUModelRunnerBase[ModelInputForNPUWithPoolingMetadata]): + + _model_input_cls: Type[ModelInputForNPUWithPoolingMetadata] = ( + ModelInputForNPUWithPoolingMetadata) + _builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForNPUWithPoolingMetadata: + return ModelInputForNPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForNPUWithPoolingMetadata: + assert seq_group_metadata_list is not None + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Prepare PoolingMetadata. + assert model_input.seq_lens is not None + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + model_input.seq_lens) + + return dataclasses.replace(model_input, + pooling_metadata=pooling_metadata) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForNPUWithPoolingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ): + if num_steps > 1: + raise ValueError( + "PoolingModelRunner does not support multi-step execution.") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + assert model_input.attn_metadata is not None + virtual_engine = model_input.virtual_engine + model_executable = self.model + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_inner_state else {} + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + import torch_npu + model_forward_start = torch_npu.npu.Event(enable_timing=True) + model_forward_end = torch_npu.npu.Event(enable_timing=True) + model_forward_start.record() + + cross_enc_kwargs = {} + if model_input.token_types is not None: + cross_enc_kwargs["token_type_ids"] = 
+
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 virtual_engine):
+            hidden_or_intermediate_states = model_executable(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                intermediate_tensors=intermediate_tensors,
+                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
+                                             device=self.device),
+                **cross_enc_kwargs,
+                **seqlen_agnostic_kwargs)
+
+        if (self.observability_config is not None
+                and self.observability_config.collect_model_forward_time):
+            model_forward_end.record()
+
+        # Only perform pooling in the last pipeline stage.
+        if not get_pp_group().is_last_rank:
+            if (self.is_driver_worker
+                    and hidden_or_intermediate_states is not None
+                    and isinstance(hidden_or_intermediate_states,
+                                   IntermediateTensors)
+                    and self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                hidden_or_intermediate_states.tensors["model_forward_time"] = (
+                    torch.tensor(model_forward_time + orig_model_forward_time))
+            return hidden_or_intermediate_states
+
+        # Only perform pooling in the driver worker.
+        if not self.is_driver_worker:
+            return []
+
+        return [
+            self.model.pooler(hidden_states=hidden_or_intermediate_states,
+                              pooling_metadata=model_input.pooling_metadata)
+        ]
diff --git a/vllm_ascend/worker.py b/vllm_ascend/worker/worker.py
similarity index 90%
rename from vllm_ascend/worker.py
rename to vllm_ascend/worker/worker.py
index bcb6bde..829a2ec 100644
--- a/vllm_ascend/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -33,7 +33,6 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
@@ -41,12 +40,13 @@ from vllm.utils import bind_kv_cache
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner_base import ModelRunnerBase
-from vllm.worker.pooling_model_runner import PoolingModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                      WorkerInput)
 
-from vllm_ascend.model_runner import NPUModelRunner
+from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import try_register_lib
+from vllm_ascend.worker.model_runner import NPUModelRunner
+from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner
 
 logger = init_logger(__name__)
 
@@ -58,15 +58,12 @@ class NPUWorker(LocalOrDistributedWorkerBase):
     distributed inference, each worker is assigned a partition of the model.
     """
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        local_rank: int,
-        rank: int,
-        distributed_init_method: str,
-        is_driver_worker: bool = False,
-        model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
-    ) -> None:
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 local_rank: int,
+                 rank: int,
+                 distributed_init_method: str,
+                 is_driver_worker: bool = False):
         # Register ops when worker init.
         from vllm_ascend import ops  # noqa: F401
 
@@ -101,7 +98,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
 
         ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
         if model_config.runner_type == "pooling":
-            ModelRunnerClass = PoolingModelRunner
+            ModelRunnerClass = NPUPoolingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner
         self.model_runner: ModelRunnerBase = ModelRunnerClass(
@@ -110,8 +107,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
             is_driver_worker=is_driver_worker,
             **speculative_args,
         )
-        if model_runner_cls is not None:
-            self.model_runner = model_runner_cls(self.model_runner)
 
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
@@ -155,6 +150,26 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         else:
             self.profiler = None
 
+    def init_device(self) -> None:
+        if self.device_config.device.type == "npu":
+            self.device = torch.device(f"npu:{self.local_rank}")
+            NPUPlatform.set_device(self.device)
+            NPUPlatform.empty_cache()
+            self.init_npu_memory = NPUPlatform.mem_get_info()[0]
+        else:
+            raise RuntimeError(
+                f"Not support device type: {self.device_config.device}")
+        # Initialize the distributed environment.
+        self._init_worker_distributed_environment(self.parallel_config,
+                                                  self.rank,
+                                                  self.distributed_init_method,
+                                                  self.local_rank)
+        # Set random seed.
+        set_random_seed(self.model_config.seed)
+
+    def load_model(self):
+        self.model_runner.load_model()
+
     def start_profile(self):
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
@@ -165,28 +180,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
             raise RuntimeError("Profiler is not enabled.")
         self.profiler.stop()
 
-    def init_device(self) -> None:
-        if self.device_config.device.type == "npu":
-            # # This env var set by Ray causes exceptions with graph building.
-            # os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            self.device = torch.device(f"npu:{self.local_rank}")
-            current_platform.set_device(self.device)
-
-            current_platform.empty_cache()
-            self.init_npu_memory = current_platform.mem_get_info()[0]
-        else:
-            raise RuntimeError(
-                f"Not support device type: {self.device_config.device}")
-        # Initialize the distributed environment.
-        init_worker_distributed_environment(self.parallel_config, self.rank,
-                                            self.distributed_init_method,
-                                            self.local_rank)
-        # Set random seed.
-        set_random_seed(self.model_config.seed)
-
-    def load_model(self):
-        self.model_runner.load_model()
-
     def save_sharded_state(
         self,
         path: str,
@@ -206,7 +199,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         self.model_runner.save_tensorized_model(
             tensorizer_config=tensorizer_config, )
 
-    @current_platform.inference_mode()
+    @NPUPlatform.inference_mode()
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Profiles the peak memory usage of the model to determine how many
         KV blocks may be allocated without OOMs.
@@ -219,7 +212,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         """
         # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
-        current_platform.empty_cache()
+        NPUPlatform.empty_cache()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
@@ -227,7 +220,7 @@
 
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
-        free_npu_memory, total_npu_memory = current_platform.mem_get_info()
+        free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
         peak_memory = self.init_npu_memory - free_npu_memory
@@ -248,7 +241,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         gc.collect()
         # TODO: don`t need impl this func after empty_cache in
         # Worker.determine_num_available_blocks() unified`
-        current_platform.empty_cache()
+        NPUPlatform.empty_cache()
         return num_npu_blocks, num_cpu_blocks
 
     def initialize_cache(self, num_gpu_blocks: int,
@@ -448,21 +441,21 @@ class NPUWorker(LocalOrDistributedWorkerBase):
                                                 self.model_config,
                                                 self.parallel_config)
 
-
-def init_worker_distributed_environment(
-        parallel_config: ParallelConfig,
-        rank: int,
-        distributed_init_method: Optional[str] = None,
-        local_rank: int = -1,
-        backend: str = "hccl") -> None:
-    """Initialize the distributed environment."""
-    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
-
-    init_distributed_environment(parallel_config.world_size, rank,
-                                 distributed_init_method, local_rank, backend)
-
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+    def _init_worker_distributed_environment(
+            self,
+            parallel_config: ParallelConfig,
+            rank: int,
+            distributed_init_method: Optional[str] = None,
+            local_rank: int = -1,
+            backend: str = "hccl") -> None:
+        """Initialize the distributed environment."""
+        set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
+        init_distributed_environment(parallel_config.world_size, rank,
+                                     distributed_init_method, local_rank,
+                                     backend)
+        ensure_model_parallel_initialized(
+            parallel_config.tensor_parallel_size,
+            parallel_config.pipeline_parallel_size)
 
 
 def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
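
Usage note (a minimal sketch, not part of the patch): once this change lands, pooling ("embedding") models are driven through vLLM's offline LLM API. The model name, the task="embed" argument, and the encode() entry point below are assumptions about the installed vLLM version rather than anything this diff defines.

    # Assumes vllm and vllm-ascend are installed and an Ascend NPU is visible.
    from vllm import LLM

    # task="embed" makes model_config.runner_type == "pooling", so NPUWorker
    # selects NPUPoolingModelRunner (see the worker.py hunk above).
    llm = LLM(model="BAAI/bge-m3", task="embed")  # model name is illustrative

    # encode() drives NPUPoolingModelRunner.execute_model(); on the driver
    # worker the pooled result comes from self.model.pooler(...).
    outputs = llm.encode(["vllm-ascend now supports pooling models"])

    # The exact result attribute (e.g. outputs[0].outputs) varies by vLLM
    # version, so just print the request output here.
    print(outputs[0])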