[Core] Support pooling (#229)

This PR adds pooling support for vllm-ascend.

Tested with `bge-base-en-v1.5` via `encode`:
```
from vllm import LLM

# Sample prompts.
prompts = [
  "Hello, my name is",
  "The president of the United States is",
  "The capital of France is",
  "The future of AI is",
]
# Create an LLM.
model = LLM(model="./bge-base-en-v1.5", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:
    print(output.outputs.embedding)  # embedding vector (768 floats for bge-base-en-v1.5)
```

Tested via `embed`:
```
from vllm import LLM

llm = LLM(model="./bge-base-en-v1.5", task="embed")
(output,) = llm.embed("Hello, my name is")

embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
```
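
For completeness, the same model can also be served through vLLM's OpenAI-compatible embeddings endpoint. This is only a sketch and was not part of this PR's testing; the local model path, the port, and the `--task embed` flag are assumptions that may need adjusting for your setup.

```
from openai import OpenAI

# Assumes the server was started with something like:
#   vllm serve ./bge-base-en-v1.5 --task embed --port 8000
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.embeddings.create(model="./bge-base-en-v1.5",
                                input=["Hello, my name is"])
print(len(resp.data[0].embedding))  # embedding dimension
```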

Related: https://github.com/vllm-project/vllm-ascend/issues/200

## Known issue
The accuracy is not correct yet, since this feature relies on `enc-dec`
support. That will be added in a follow-up PR by @MengqingCao.
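
Once that lands, one way to sanity-check the numbers is to compare against a CPU reference. A minimal sketch, assuming `sentence-transformers` is installed and the same local model directory is used:

```
import numpy as np
from sentence_transformers import SentenceTransformer
from vllm import LLM

prompt = "Hello, my name is"

# Reference embedding computed with sentence-transformers.
ref = SentenceTransformer("./bge-base-en-v1.5").encode(
    [prompt], normalize_embeddings=True)[0]

# Embedding from vllm-ascend.
llm = LLM(model="./bge-base-en-v1.5", task="embed", enforce_eager=True)
(out,) = llm.embed(prompt)
emb = np.asarray(out.outputs.embedding)
emb = emb / np.linalg.norm(emb)

# Cosine similarity should be close to 1.0 once accuracy is fixed.
print(float(np.dot(ref, emb)))
```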

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-03-04 15:59:34 +08:00 (committed by GitHub)
Commit: ae49bfd13a (parent: 8fda31cafe)
7 changed files with 258 additions and 71 deletions


@@ -7,7 +7,7 @@
 | LoRA | ✗ | Plan in 2025 Q1 |
 | Prompt adapter | ✗ | Plan in 2025 Q1 |
 | Speculative decoding | ✗ | Plan in 2025 Q1 |
-| Pooling | | Plan in 2025 Q2 |
+| Pooling | | |
 | Enc-dec | ✗ | Plan in 2025 Q2 |
 | Multi Modality | ✅ (LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Add more model support in 2025 Q1 |
 | LogProbs | ✅ ||


@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 if TYPE_CHECKING:
-    from vllm_ascend.model_runner import ModelInputForNPUBuilder
+    from vllm_ascend.worker.model_runner import ModelInputForNPUBuilder


 def generate_attn_mask(max_seq_len: int, dtype=torch.float16):

@@ -211,6 +211,9 @@ class AscendMetadata(AttentionMetadata):
     # the computed tokens + new tokens None if it is a decoding.
     seq_lens: Optional[List[int]] = None

+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor] = None
+
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None

@@ -258,6 +261,9 @@ class AscendMetadata(AttentionMetadata):
         block_tables = (None if self.block_tables is None else
                         self.block_tables[:self.num_prefills])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[:self.num_prefills])

         # Construct & cache prefill-phase attention metadata structure.
         self._cached_prefill_metadata = AscendMetadata(
             num_prefills=self.num_prefills,

@@ -265,6 +271,7 @@ class AscendMetadata(AttentionMetadata):
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
             max_decode_seq_len=0,

@@ -297,7 +304,8 @@ class AscendMetadata(AttentionMetadata):
                     self.seq_lens[self.num_prefills:])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[self.num_prefills:])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[self.num_prefills:])

         # Construct & cache decode-phase attention metadata structure.
         self._cached_decode_metadata = AscendMetadata(
             num_prefills=0,

@@ -305,6 +313,7 @@ class AscendMetadata(AttentionMetadata):
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
             block_tables=block_tables,

@@ -322,7 +331,6 @@ class AscendMetadata(AttentionMetadata):

 class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):

-    _metadata_cls = AscendMetadata
     _attn_mask_builder = None  # noqa

     def __init__(self, input_builder: "ModelInputForNPUBuilder"):

@@ -451,7 +459,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
                 self.multimodal_placeholder_maps.items()
             }

-        return self._metadata_cls(  # type: ignore
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.long,
+                                       device=device)
+
+        return AscendMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
             multi_modal_placeholder_index_maps=placeholder_index_maps,

@@ -459,6 +471,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=self.num_decode_tokens,
             seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,

@@ -528,12 +541,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
             shape = [batch_size, seq_len * num_heads * head_size]
         """
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-        attn_type = self.attn_type
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "PallasAttentionBackendImpl")
         # View q k v to BSH.
         num_tokens = query.shape[0]
         query = query.view(-1, self.num_heads, self.head_size)
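
As an aside (not part of the diff), the new `seq_lens_tensor` field follows the same prefill/decode slicing convention the metadata already uses for `seq_lens` and `block_tables`; a toy illustration:

```
import torch

# Hypothetical batch: two prefill sequences followed by two decode sequences.
seq_lens = [5, 7, 3, 1]
num_prefills = 2

seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long)
prefill_lens = seq_lens_tensor[:num_prefills]  # tensor([5, 7])
decode_lens = seq_lens_tensor[num_prefills:]   # tensor([3, 1])
print(prefill_lens, decode_lens)
```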


@@ -105,7 +105,7 @@ class NPUPlatform(Platform):
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
-            parallel_config.worker_cls = "vllm_ascend.worker.NPUWorker"
+            parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             # TODO: Set block_size to 128 will lead unexpected accuracy issue in mla case. Please set block_size to 128 back once the problem is fixed.


@@ -808,6 +808,9 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
             SamplingMetadataCache() \
                 if self.parallel_config.pipeline_parallel_size == 1 else None

+    def get_model(self) -> nn.Module:
+        return self.model
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:

@@ -1341,6 +1344,3 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
             self.execute_model(model_input, kv_caches, intermediate_tensors)
             current_platform.synchronize()
             return
-
-    def get_model(self) -> nn.Module:
-        return self.model


@@ -0,0 +1,187 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/worker.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm_ascend.worker.model_runner import (ModelInputForNPU,
ModelInputForNPUBuilder,
NPUModelRunnerBase)
@dataclasses.dataclass(frozen=True)
class ModelInputForNPUWithPoolingMetadata(ModelInputForNPU):
"""
Used by the PoolingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class NPUPoolingModelRunner(
NPUModelRunnerBase[ModelInputForNPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForNPUWithPoolingMetadata] = (
ModelInputForNPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForNPUWithPoolingMetadata:
return ModelInputForNPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForNPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
model_input.seq_lens)
return dataclasses.replace(model_input,
pooling_metadata=pooling_metadata)
def _prepare_pooling(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> PoolingMetadata:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
pooling_metadata = PoolingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
)
return pooling_metadata
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForNPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
):
if num_steps > 1:
raise ValueError(
"PoolingModelRunner does not support multi-step execution.")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
assert model_input.attn_metadata is not None
virtual_engine = model_input.virtual_engine
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
import torch_npu
model_forward_start = torch_npu.npu.Event(enable_timing=True)
model_forward_end = torch_npu.npu.Event(enable_timing=True)
model_forward_start.record()
cross_enc_kwargs = {}
if model_input.token_types is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_types
with set_forward_context(model_input.attn_metadata, self.vllm_config,
virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**cross_enc_kwargs,
**seqlen_agnostic_kwargs)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
# Only perform pooling in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
and hidden_or_intermediate_states is not None
and isinstance(hidden_or_intermediate_states,
IntermediateTensors)
and self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
hidden_or_intermediate_states.tensors["model_forward_time"] = (
torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states
# Only perform pooling in the driver worker.
if not self.is_driver_worker:
return []
return [
self.model.pooler(hidden_states=hidden_or_intermediate_states,
pooling_metadata=model_input.pooling_metadata)
]


@@ -33,7 +33,6 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)

@@ -41,12 +40,13 @@ from vllm.utils import bind_kv_cache
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner_base import ModelRunnerBase
-from vllm.worker.pooling_model_runner import PoolingModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                      WorkerInput)

-from vllm_ascend.model_runner import NPUModelRunner
+from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import try_register_lib
+from vllm_ascend.worker.model_runner import NPUModelRunner
+from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner

 logger = init_logger(__name__)

@@ -58,15 +58,12 @@ class NPUWorker(LocalOrDistributedWorkerBase):
     distributed inference, each worker is assigned a partition of the model.
     """

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        local_rank: int,
-        rank: int,
-        distributed_init_method: str,
-        is_driver_worker: bool = False,
-        model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
-    ) -> None:
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 local_rank: int,
+                 rank: int,
+                 distributed_init_method: str,
+                 is_driver_worker: bool = False):
         # Register ops when worker init.
         from vllm_ascend import ops  # noqa: F401

@@ -101,7 +98,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
         if model_config.runner_type == "pooling":
-            ModelRunnerClass = PoolingModelRunner
+            ModelRunnerClass = NPUPoolingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner
         self.model_runner: ModelRunnerBase = ModelRunnerClass(

@@ -110,8 +107,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
             is_driver_worker=is_driver_worker,
             **speculative_args,
         )
-        if model_runner_cls is not None:
-            self.model_runner = model_runner_cls(self.model_runner)

         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.

@@ -155,6 +150,26 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         else:
             self.profiler = None

+    def init_device(self) -> None:
+        if self.device_config.device.type == "npu":
+            self.device = torch.device(f"npu:{self.local_rank}")
+            NPUPlatform.set_device(self.device)
+            NPUPlatform.empty_cache()
+            self.init_npu_memory = NPUPlatform.mem_get_info()[0]
+        else:
+            raise RuntimeError(
+                f"Not support device type: {self.device_config.device}")
+        # Initialize the distributed environment.
+        self._init_worker_distributed_environment(self.parallel_config,
+                                                   self.rank,
+                                                   self.distributed_init_method,
+                                                   self.local_rank)
+        # Set random seed.
+        set_random_seed(self.model_config.seed)
+
+    def load_model(self):
+        self.model_runner.load_model()
+
     def start_profile(self):
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")

@@ -165,28 +180,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
             raise RuntimeError("Profiler is not enabled.")
         self.profiler.stop()

-    def init_device(self) -> None:
-        if self.device_config.device.type == "npu":
-            # # This env var set by Ray causes exceptions with graph building.
-            # os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            self.device = torch.device(f"npu:{self.local_rank}")
-            current_platform.set_device(self.device)
-            current_platform.empty_cache()
-            self.init_npu_memory = current_platform.mem_get_info()[0]
-        else:
-            raise RuntimeError(
-                f"Not support device type: {self.device_config.device}")
-        # Initialize the distributed environment.
-        init_worker_distributed_environment(self.parallel_config, self.rank,
-                                            self.distributed_init_method,
-                                            self.local_rank)
-        # Set random seed.
-        set_random_seed(self.model_config.seed)
-
-    def load_model(self):
-        self.model_runner.load_model()
-
     def save_sharded_state(
         self,
         path: str,

@@ -206,7 +199,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         self.model_runner.save_tensorized_model(
             tensorizer_config=tensorizer_config, )

-    @current_platform.inference_mode()
+    @NPUPlatform.inference_mode()
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Profiles the peak memory usage of the model to determine how many
         KV blocks may be allocated without OOMs.

@@ -219,7 +212,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         """
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
-        current_platform.empty_cache()
+        NPUPlatform.empty_cache()

         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.

@@ -227,7 +220,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
-        free_npu_memory, total_npu_memory = current_platform.mem_get_info()
+        free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()

         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
         peak_memory = self.init_npu_memory - free_npu_memory

@@ -248,7 +241,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         gc.collect()
         # TODO: don`t need impl this func after empty_cache in
         # Worker.determine_num_available_blocks() unified`
-        current_platform.empty_cache()
+        NPUPlatform.empty_cache()
         return num_npu_blocks, num_cpu_blocks

     def initialize_cache(self, num_gpu_blocks: int,

@@ -448,8 +441,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
                                                  self.model_config,
                                                  self.parallel_config)

-
-def init_worker_distributed_environment(
+    def _init_worker_distributed_environment(
+        self,
         parallel_config: ParallelConfig,
         rank: int,
         distributed_init_method: Optional[str] = None,

@@ -457,11 +450,11 @@ def init_worker_distributed_environment(
         backend: str = "hccl") -> None:
         """Initialize the distributed environment."""
         set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

         init_distributed_environment(parallel_config.world_size, rank,
-                                     distributed_init_method, local_rank, backend)
+                                     distributed_init_method, local_rank,
+                                     backend)

-        ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                          parallel_config.pipeline_parallel_size)
+        ensure_model_parallel_initialized(
+            parallel_config.tensor_parallel_size,
+            parallel_config.pipeline_parallel_size)