[Core] Support pooling (#229)
This PR adds pooling support for vllm-ascend.
Tested with `bge-base-en-v1.5` by encode:
```
from vllm import LLM
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
model = LLM(model="./bge-base-en-v1.5", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:
print(output.outputs.embedding) # list of 768 floats (bge-base-en-v1.5 hidden size)
```
Tested via `embed`:
```
from vllm import LLM, SamplingParams
llm = LLM(model="./bge-base-en-v1.5", task="embed")
(output,) = llm.embed("Hello, my name is")
embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
```
Related: https://github.com/vllm-project/vllm-ascend/issues/200
## Known issue
The accuracy is not correct yet, since this feature relies on `enc-dec`
support. That will be addressed in a follow-up PR by @MengqingCao.
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
| LoRA | ✗ | Plan in 2025 Q1 |
|
||||
| Prompt adapter | ✗ | Plan in 2025 Q1 |
|
||||
| Speculative decoding | ✗ | Plan in 2025 Q1 |
|
||||
| Pooling | ✗ | Plan in 2025 Q2 |
|
||||
| Pooling | ✅ | |
|
||||
| Enc-dec | ✗ | Plan in 2025 Q2 |
|
||||
| Multi Modality | ✅ (LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Add more model support in 2025 Q1 |
|
||||
| LogProbs | ✅ ||
|
||||
|
||||
@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_ascend.model_runner import ModelInputForNPUBuilder
|
||||
from vllm_ascend.worker.model_runner import ModelInputForNPUBuilder
|
||||
|
||||
|
||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
||||
@@ -211,6 +211,9 @@ class AscendMetadata(AttentionMetadata):
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
@@ -258,6 +261,9 @@ class AscendMetadata(AttentionMetadata):
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure.
|
||||
self._cached_prefill_metadata = AscendMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
@@ -265,6 +271,7 @@ class AscendMetadata(AttentionMetadata):
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
@@ -297,7 +304,8 @@ class AscendMetadata(AttentionMetadata):
|
||||
self.seq_lens[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
# Construct & cache decode-phase attention metadata structure.
|
||||
self._cached_decode_metadata = AscendMetadata(
|
||||
num_prefills=0,
|
||||
@@ -305,6 +313,7 @@ class AscendMetadata(AttentionMetadata):
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
block_tables=block_tables,
|
||||
@@ -322,7 +331,6 @@ class AscendMetadata(AttentionMetadata):
|
||||
|
||||
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
|
||||
_metadata_cls = AscendMetadata
|
||||
_attn_mask_builder = None # noqa
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForNPUBuilder"):
|
||||
@@ -451,7 +459,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
return self._metadata_cls( # type: ignore
|
||||
seq_lens_tensor = torch.tensor(seq_lens,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
return AscendMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
@@ -459,6 +471,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
@@ -528,12 +541,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
shape = [batch_size, seq_len * num_heads * head_size]
|
||||
"""
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
attn_type = self.attn_type
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
# View q k v to BSH.
|
||||
num_tokens = query.shape[0]
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
@@ -105,7 +105,7 @@ class NPUPlatform(Platform):
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm_ascend.worker.NPUWorker"
|
||||
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config and cache_config.block_size is None:
|
||||
# TODO: Set block_size to 128 will lead unexpected accuracy issue in mla case. Please set block_size to 128 back once the problem is fixed.
|
||||
|
||||
0
vllm_ascend/worker/__init__.py
Normal file
0
vllm_ascend/worker/__init__.py
Normal file
@@ -808,6 +808,9 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
SamplingMetadataCache() \
|
||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def load_model(self) -> None:
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
with DeviceMemoryProfiler() as m:
|
||||
@@ -1341,6 +1344,3 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
self.execute_model(model_input, kv_caches, intermediate_tensors)
|
||||
current_platform.synchronize()
|
||||
return
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
187
vllm_ascend/worker/pooling_model_runner.py
Normal file
187
vllm_ascend/worker/pooling_model_runner.py
Normal file
@@ -0,0 +1,187 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import (IntermediateTensors, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
|
||||
from vllm_ascend.worker.model_runner import (ModelInputForNPU,
|
||||
ModelInputForNPUBuilder,
|
||||
NPUModelRunnerBase)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class ModelInputForNPUWithPoolingMetadata(ModelInputForNPU):
    """Model input for the NPU pooling model runner.

    Extends the base NPU model input with the pooling metadata needed to
    run the model's pooler after the forward pass.

    Used by the PoolingModelRunner.
    """
    # Pooling metadata built by NPUPoolingModelRunner._prepare_pooling;
    # None until prepare_model_input fills it in.
    pooling_metadata: Optional["PoolingMetadata"] = None
|
||||
|
||||
|
||||
class NPUPoolingModelRunner(
        NPUModelRunnerBase[ModelInputForNPUWithPoolingMetadata]):
    """Model runner for pooling (embedding) models on Ascend NPU.

    Instead of sampling tokens, this runner executes a single forward pass
    and feeds the resulting hidden states to the model's pooler to produce
    embeddings. Pooling is performed only on the last pipeline-parallel
    stage and only in the driver worker.
    """

    _model_input_cls: Type[ModelInputForNPUWithPoolingMetadata] = (
        ModelInputForNPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForNPUWithPoolingMetadata:
        """Rebuild a model input from a tensor dict broadcast by the driver."""
        return ModelInputForNPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNPUWithPoolingMetadata:
        """Build the model input tensors and attach PoolingMetadata.

        Args:
            seq_group_metadata_list: Metadata for every sequence group in
                the batch. Must not be None.
            virtual_engine: Virtual engine index (unused here but part of
                the runner interface).
            finished_requests_ids: Request ids that finished since the last
                step, forwarded to the tensor builder.
        """
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        # model_input is a frozen dataclass, so attach the pooling metadata
        # via dataclasses.replace rather than mutation.
        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        seq_data: Dict[int, SequenceData] = {}
        # Single pass: collect (seq_ids, pooling_params) per group and merge
        # all per-sequence data into one mapping.
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            seq_groups.append((seq_ids, seq_group_metadata.pooling_params))
            seq_data.update(seq_group_metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ):
        """Run one forward pass and pool the hidden states.

        Returns the pooler output on the driver worker of the last pipeline
        stage, the IntermediateTensors on earlier stages, or an empty list
        on non-driver workers. Raises ValueError if num_steps > 1, since
        pooling does not support multi-step execution.
        """
        if num_steps > 1:
            raise ValueError(
                "PoolingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        assert model_input.attn_metadata is not None
        virtual_engine = model_input.virtual_engine
        model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            # Imported lazily so the module loads on hosts without torch_npu.
            import torch_npu
            model_forward_start = torch_npu.npu.Event(enable_timing=True)
            model_forward_end = torch_npu.npu.Event(enable_timing=True)
            model_forward_start.record()

        # Cross-encoder models (e.g. BERT-style) additionally take
        # token_type_ids; pass them through only when present.
        cross_enc_kwargs = {}
        if model_input.token_types is not None:
            cross_enc_kwargs["token_type_ids"] = model_input.token_types

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
                **cross_enc_kwargs,
                **seqlen_agnostic_kwargs)

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end.record()

        # Only perform pooling in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            # On intermediate stages, optionally fold this stage's forward
            # time into the IntermediateTensors before forwarding them.
            if (self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)
                    and self.observability_config is not None
                    and self.observability_config.collect_model_forward_time):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_or_intermediate_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]
|
||||
@@ -33,7 +33,6 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
||||
@@ -41,12 +40,13 @@ from vllm.utils import bind_kv_cache
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
from vllm.worker.model_runner_base import ModelRunnerBase
|
||||
from vllm.worker.pooling_model_runner import PoolingModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
from vllm_ascend.model_runner import NPUModelRunner
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import try_register_lib
|
||||
from vllm_ascend.worker.model_runner import NPUModelRunner
|
||||
from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -58,15 +58,12 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
distributed inference, each worker is assigned a partition of the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False,
|
||||
model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False):
|
||||
# Register ops when worker init.
|
||||
from vllm_ascend import ops # noqa: F401
|
||||
|
||||
@@ -101,7 +98,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
|
||||
ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
|
||||
if model_config.runner_type == "pooling":
|
||||
ModelRunnerClass = PoolingModelRunner
|
||||
ModelRunnerClass = NPUPoolingModelRunner
|
||||
elif self.model_config.is_encoder_decoder:
|
||||
ModelRunnerClass = EncoderDecoderModelRunner
|
||||
self.model_runner: ModelRunnerBase = ModelRunnerClass(
|
||||
@@ -110,8 +107,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
is_driver_worker=is_driver_worker,
|
||||
**speculative_args,
|
||||
)
|
||||
if model_runner_cls is not None:
|
||||
self.model_runner = model_runner_cls(self.model_runner)
|
||||
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
@@ -155,6 +150,26 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "npu":
|
||||
self.device = torch.device(f"npu:{self.local_rank}")
|
||||
NPUPlatform.set_device(self.device)
|
||||
NPUPlatform.empty_cache()
|
||||
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
self._init_worker_distributed_environment(self.parallel_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
def start_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
@@ -165,28 +180,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.stop()
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "npu":
|
||||
# # This env var set by Ray causes exceptions with graph building.
|
||||
# os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
||||
self.device = torch.device(f"npu:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
|
||||
current_platform.empty_cache()
|
||||
self.init_npu_memory = current_platform.mem_get_info()[0]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
init_worker_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
@@ -206,7 +199,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
self.model_runner.save_tensorized_model(
|
||||
tensorizer_config=tensorizer_config, )
|
||||
|
||||
@current_platform.inference_mode()
|
||||
@NPUPlatform.inference_mode()
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
KV blocks may be allocated without OOMs.
|
||||
@@ -219,7 +212,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
"""
|
||||
# Profile the memory usage of the model and get the maximum number of
|
||||
# cache blocks that can be allocated with the remaining free memory.
|
||||
current_platform.empty_cache()
|
||||
NPUPlatform.empty_cache()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
@@ -227,7 +220,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
free_npu_memory, total_npu_memory = current_platform.mem_get_info()
|
||||
free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
peak_memory = self.init_npu_memory - free_npu_memory
|
||||
@@ -248,7 +241,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
gc.collect()
|
||||
# TODO: don`t need impl this func after empty_cache in
|
||||
# Worker.determine_num_available_blocks() unified`
|
||||
current_platform.empty_cache()
|
||||
NPUPlatform.empty_cache()
|
||||
return num_npu_blocks, num_cpu_blocks
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
@@ -448,21 +441,21 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
self.model_config,
|
||||
self.parallel_config)
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
backend: str = "hccl") -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank, backend)
|
||||
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
def _init_worker_distributed_environment(
|
||||
self,
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
backend: str = "hccl") -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank,
|
||||
backend)
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
|
||||
Reference in New Issue
Block a user