[Core] Support pooling (#229)
This PR added pooling support for vllm-ascend
Tested with `bge-base-en-v1.5` via `encode`:
```
from vllm import LLM
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
model = LLM(model="./bge-base-en-v1.5", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:
print(output.outputs.embedding) # list of 768 floats
```
Tested via `embed`:
```
from vllm import LLM, SamplingParams
llm = LLM(model="./bge-base-en-v1.5", task="embed")
(output,) = llm.embed("Hello, my name is")
embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
```
Related: https://github.com/vllm-project/vllm-ascend/issues/200
## Known issue
The accuracy is not correct yet, since this feature relies on `enc-dec`
support. That will be addressed in a follow-up PR by @MengqingCao.
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -7,7 +7,7 @@
|
|||||||
| LoRA | ✗ | Plan in 2025 Q1 |
|
| LoRA | ✗ | Plan in 2025 Q1 |
|
||||||
| Prompt adapter | ✗ | Plan in 2025 Q1 |
|
| Prompt adapter | ✗ | Plan in 2025 Q1 |
|
||||||
| Speculative decoding | ✗ | Plan in 2025 Q1 |
|
| Speculative decoding | ✗ | Plan in 2025 Q1 |
|
||||||
| Pooling | ✗ | Plan in 2025 Q2 |
|
| Pooling | ✅ | |
|
||||||
| Enc-dec | ✗ | Plan in 2025 Q2 |
|
| Enc-dec | ✗ | Plan in 2025 Q2 |
|
||||||
| Multi Modality | ✅ (LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Add more model support in 2025 Q1 |
|
| Multi Modality | ✅ (LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Add more model support in 2025 Q1 |
|
||||||
| LogProbs | ✅ ||
|
| LogProbs | ✅ ||
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
|
|||||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm_ascend.model_runner import ModelInputForNPUBuilder
|
from vllm_ascend.worker.model_runner import ModelInputForNPUBuilder
|
||||||
|
|
||||||
|
|
||||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
||||||
@@ -211,6 +211,9 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
# the computed tokens + new tokens None if it is a decoding.
|
# the computed tokens + new tokens None if it is a decoding.
|
||||||
seq_lens: Optional[List[int]] = None
|
seq_lens: Optional[List[int]] = None
|
||||||
|
|
||||||
|
# seq_lens stored as a tensor.
|
||||||
|
seq_lens_tensor: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Maximum query length in the batch. None for decoding.
|
# Maximum query length in the batch. None for decoding.
|
||||||
max_query_len: Optional[int] = None
|
max_query_len: Optional[int] = None
|
||||||
|
|
||||||
@@ -258,6 +261,9 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
block_tables = (None if self.block_tables is None else
|
block_tables = (None if self.block_tables is None else
|
||||||
self.block_tables[:self.num_prefills])
|
self.block_tables[:self.num_prefills])
|
||||||
|
|
||||||
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||||
|
self.seq_lens_tensor[:self.num_prefills])
|
||||||
|
|
||||||
# Construct & cache prefill-phase attention metadata structure.
|
# Construct & cache prefill-phase attention metadata structure.
|
||||||
self._cached_prefill_metadata = AscendMetadata(
|
self._cached_prefill_metadata = AscendMetadata(
|
||||||
num_prefills=self.num_prefills,
|
num_prefills=self.num_prefills,
|
||||||
@@ -265,6 +271,7 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
num_decode_tokens=0,
|
num_decode_tokens=0,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_query_len=self.max_query_len,
|
max_query_len=self.max_query_len,
|
||||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||||
max_decode_seq_len=0,
|
max_decode_seq_len=0,
|
||||||
@@ -297,7 +304,8 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
self.seq_lens[self.num_prefills:])
|
self.seq_lens[self.num_prefills:])
|
||||||
block_tables = (None if self.block_tables is None else
|
block_tables = (None if self.block_tables is None else
|
||||||
self.block_tables[self.num_prefills:])
|
self.block_tables[self.num_prefills:])
|
||||||
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||||
|
self.seq_lens_tensor[self.num_prefills:])
|
||||||
# Construct & cache decode-phase attention metadata structure.
|
# Construct & cache decode-phase attention metadata structure.
|
||||||
self._cached_decode_metadata = AscendMetadata(
|
self._cached_decode_metadata = AscendMetadata(
|
||||||
num_prefills=0,
|
num_prefills=0,
|
||||||
@@ -305,6 +313,7 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
num_decode_tokens=self.num_decode_tokens,
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
max_decode_seq_len=self.max_decode_seq_len,
|
max_decode_seq_len=self.max_decode_seq_len,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
@@ -322,7 +331,6 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
|
|
||||||
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||||
|
|
||||||
_metadata_cls = AscendMetadata
|
|
||||||
_attn_mask_builder = None # noqa
|
_attn_mask_builder = None # noqa
|
||||||
|
|
||||||
def __init__(self, input_builder: "ModelInputForNPUBuilder"):
|
def __init__(self, input_builder: "ModelInputForNPUBuilder"):
|
||||||
@@ -451,7 +459,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
self.multimodal_placeholder_maps.items()
|
self.multimodal_placeholder_maps.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
return self._metadata_cls( # type: ignore
|
seq_lens_tensor = torch.tensor(seq_lens,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
return AscendMetadata(
|
||||||
num_prefills=self.num_prefills,
|
num_prefills=self.num_prefills,
|
||||||
slot_mapping=slot_mapping_tensor,
|
slot_mapping=slot_mapping_tensor,
|
||||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||||
@@ -459,6 +471,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
num_prefill_tokens=self.num_prefill_tokens,
|
num_prefill_tokens=self.num_prefill_tokens,
|
||||||
num_decode_tokens=self.num_decode_tokens,
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
max_prefill_seq_len=max_prefill_seq_len,
|
max_prefill_seq_len=max_prefill_seq_len,
|
||||||
max_decode_seq_len=max_decode_seq_len,
|
max_decode_seq_len=max_decode_seq_len,
|
||||||
@@ -528,12 +541,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
shape = [batch_size, seq_len * num_heads * head_size]
|
shape = [batch_size, seq_len * num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||||
attn_type = self.attn_type
|
|
||||||
if attn_type != AttentionType.DECODER:
|
|
||||||
raise NotImplementedError("Encoder self-attention and "
|
|
||||||
"encoder/decoder cross-attention "
|
|
||||||
"are not implemented for "
|
|
||||||
"PallasAttentionBackendImpl")
|
|
||||||
# View q k v to BSH.
|
# View q k v to BSH.
|
||||||
num_tokens = query.shape[0]
|
num_tokens = query.shape[0]
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ class NPUPlatform(Platform):
|
|||||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
parallel_config.worker_cls = "vllm_ascend.worker.NPUWorker"
|
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
if cache_config and cache_config.block_size is None:
|
if cache_config and cache_config.block_size is None:
|
||||||
# TODO: Set block_size to 128 will lead unexpected accuracy issue in mla case. Please set block_size to 128 back once the problem is fixed.
|
# TODO: Set block_size to 128 will lead unexpected accuracy issue in mla case. Please set block_size to 128 back once the problem is fixed.
|
||||||
|
|||||||
0
vllm_ascend/worker/__init__.py
Normal file
0
vllm_ascend/worker/__init__.py
Normal file
@@ -808,6 +808,9 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
|||||||
SamplingMetadataCache() \
|
SamplingMetadataCache() \
|
||||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||||
|
|
||||||
|
def get_model(self) -> nn.Module:
|
||||||
|
return self.model
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
logger.info("Starting to load model %s...", self.model_config.model)
|
logger.info("Starting to load model %s...", self.model_config.model)
|
||||||
with DeviceMemoryProfiler() as m:
|
with DeviceMemoryProfiler() as m:
|
||||||
@@ -1341,6 +1344,3 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
|||||||
self.execute_model(model_input, kv_caches, intermediate_tensors)
|
self.execute_model(model_input, kv_caches, intermediate_tensors)
|
||||||
current_platform.synchronize()
|
current_platform.synchronize()
|
||||||
return
|
return
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
|
||||||
return self.model
|
|
||||||
187
vllm_ascend/worker/pooling_model_runner.py
Normal file
187
vllm_ascend/worker/pooling_model_runner.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import dataclasses
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.distributed import get_pp_group
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
|
from vllm.multimodal import MultiModalKwargs
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
from vllm.sequence import (IntermediateTensors, SequenceData,
|
||||||
|
SequenceGroupMetadata)
|
||||||
|
|
||||||
|
from vllm_ascend.worker.model_runner import (ModelInputForNPU,
|
||||||
|
ModelInputForNPUBuilder,
|
||||||
|
NPUModelRunnerBase)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class ModelInputForNPUWithPoolingMetadata(ModelInputForNPU):
|
||||||
|
"""
|
||||||
|
Used by the PoolingModelRunner.
|
||||||
|
"""
|
||||||
|
pooling_metadata: Optional["PoolingMetadata"] = None
|
||||||
|
|
||||||
|
|
||||||
|
class NPUPoolingModelRunner(
|
||||||
|
NPUModelRunnerBase[ModelInputForNPUWithPoolingMetadata]):
|
||||||
|
|
||||||
|
_model_input_cls: Type[ModelInputForNPUWithPoolingMetadata] = (
|
||||||
|
ModelInputForNPUWithPoolingMetadata)
|
||||||
|
_builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder
|
||||||
|
|
||||||
|
def make_model_input_from_broadcasted_tensor_dict(
|
||||||
|
self,
|
||||||
|
tensor_dict: Dict[str,
|
||||||
|
Any]) -> ModelInputForNPUWithPoolingMetadata:
|
||||||
|
return ModelInputForNPUWithPoolingMetadata.from_broadcasted_tensor_dict(
|
||||||
|
tensor_dict,
|
||||||
|
attn_backend=self.attn_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_model_input(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
|
) -> ModelInputForNPUWithPoolingMetadata:
|
||||||
|
assert seq_group_metadata_list is not None
|
||||||
|
model_input = self._prepare_model_input_tensors(
|
||||||
|
seq_group_metadata_list, finished_requests_ids)
|
||||||
|
# Prepare PoolingMetadata.
|
||||||
|
assert model_input.seq_lens is not None
|
||||||
|
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
||||||
|
model_input.seq_lens)
|
||||||
|
|
||||||
|
return dataclasses.replace(model_input,
|
||||||
|
pooling_metadata=pooling_metadata)
|
||||||
|
|
||||||
|
def _prepare_pooling(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
prompt_lens: List[int],
|
||||||
|
) -> PoolingMetadata:
|
||||||
|
"""Prepare PoolingMetadata for the sequence group metadata list."""
|
||||||
|
seq_groups: List[Tuple[List[int], PoolingParams]] = []
|
||||||
|
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||||
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||||
|
pooling_params = seq_group_metadata.pooling_params
|
||||||
|
seq_groups.append((seq_ids, pooling_params))
|
||||||
|
|
||||||
|
seq_data: Dict[int, SequenceData] = {}
|
||||||
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
|
seq_data.update(seq_group_metadata.seq_data)
|
||||||
|
|
||||||
|
pooling_metadata = PoolingMetadata(
|
||||||
|
seq_groups=seq_groups,
|
||||||
|
seq_data=seq_data,
|
||||||
|
prompt_lens=prompt_lens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pooling_metadata
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
model_input: ModelInputForNPUWithPoolingMetadata,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
num_steps: int = 1,
|
||||||
|
):
|
||||||
|
if num_steps > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"PoolingModelRunner does not support multi-step execution.")
|
||||||
|
|
||||||
|
if self.lora_config:
|
||||||
|
assert model_input.lora_requests is not None
|
||||||
|
assert model_input.lora_mapping is not None
|
||||||
|
self.set_active_loras(model_input.lora_requests,
|
||||||
|
model_input.lora_mapping)
|
||||||
|
|
||||||
|
if self.prompt_adapter_config:
|
||||||
|
assert model_input.prompt_adapter_requests is not None
|
||||||
|
assert model_input.prompt_adapter_mapping is not None
|
||||||
|
self.set_active_prompt_adapters(
|
||||||
|
model_input.prompt_adapter_requests,
|
||||||
|
model_input.prompt_adapter_mapping)
|
||||||
|
|
||||||
|
assert model_input.attn_metadata is not None
|
||||||
|
virtual_engine = model_input.virtual_engine
|
||||||
|
model_executable = self.model
|
||||||
|
|
||||||
|
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||||
|
seqlen_agnostic_kwargs = {
|
||||||
|
"finished_requests_ids": model_input.finished_requests_ids,
|
||||||
|
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||||
|
} if self.has_inner_state else {}
|
||||||
|
if (self.observability_config is not None
|
||||||
|
and self.observability_config.collect_model_forward_time):
|
||||||
|
import torch_npu
|
||||||
|
model_forward_start = torch_npu.npu.Event(enable_timing=True)
|
||||||
|
model_forward_end = torch_npu.npu.Event(enable_timing=True)
|
||||||
|
model_forward_start.record()
|
||||||
|
|
||||||
|
cross_enc_kwargs = {}
|
||||||
|
if model_input.token_types is not None:
|
||||||
|
cross_enc_kwargs["token_type_ids"] = model_input.token_types
|
||||||
|
|
||||||
|
with set_forward_context(model_input.attn_metadata, self.vllm_config,
|
||||||
|
virtual_engine):
|
||||||
|
hidden_or_intermediate_states = model_executable(
|
||||||
|
input_ids=model_input.input_tokens,
|
||||||
|
positions=model_input.input_positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||||
|
device=self.device),
|
||||||
|
**cross_enc_kwargs,
|
||||||
|
**seqlen_agnostic_kwargs)
|
||||||
|
|
||||||
|
if (self.observability_config is not None
|
||||||
|
and self.observability_config.collect_model_forward_time):
|
||||||
|
model_forward_end.record()
|
||||||
|
|
||||||
|
# Only perform pooling in the last pipeline stage.
|
||||||
|
if not get_pp_group().is_last_rank:
|
||||||
|
if (self.is_driver_worker
|
||||||
|
and hidden_or_intermediate_states is not None
|
||||||
|
and isinstance(hidden_or_intermediate_states,
|
||||||
|
IntermediateTensors)
|
||||||
|
and self.observability_config is not None
|
||||||
|
and self.observability_config.collect_model_forward_time):
|
||||||
|
model_forward_end.synchronize()
|
||||||
|
model_forward_time = model_forward_start.elapsed_time(
|
||||||
|
model_forward_end)
|
||||||
|
orig_model_forward_time = 0.0
|
||||||
|
if intermediate_tensors is not None:
|
||||||
|
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||||
|
"model_forward_time", torch.tensor(0.0)).item()
|
||||||
|
hidden_or_intermediate_states.tensors["model_forward_time"] = (
|
||||||
|
torch.tensor(model_forward_time + orig_model_forward_time))
|
||||||
|
return hidden_or_intermediate_states
|
||||||
|
|
||||||
|
# Only perform pooling in the driver worker.
|
||||||
|
if not self.is_driver_worker:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
self.model.pooler(hidden_states=hidden_or_intermediate_states,
|
||||||
|
pooling_metadata=model_input.pooling_metadata)
|
||||||
|
]
|
||||||
@@ -33,7 +33,6 @@ from vllm.lora.request import LoRARequest
|
|||||||
from vllm.model_executor import set_random_seed
|
from vllm.model_executor import set_random_seed
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||||||
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
||||||
@@ -41,12 +40,13 @@ from vllm.utils import bind_kv_cache
|
|||||||
from vllm.worker.cache_engine import CacheEngine
|
from vllm.worker.cache_engine import CacheEngine
|
||||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||||
from vllm.worker.model_runner_base import ModelRunnerBase
|
from vllm.worker.model_runner_base import ModelRunnerBase
|
||||||
from vllm.worker.pooling_model_runner import PoolingModelRunner
|
|
||||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||||
WorkerInput)
|
WorkerInput)
|
||||||
|
|
||||||
from vllm_ascend.model_runner import NPUModelRunner
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.utils import try_register_lib
|
from vllm_ascend.utils import try_register_lib
|
||||||
|
from vllm_ascend.worker.model_runner import NPUModelRunner
|
||||||
|
from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -58,15 +58,12 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
distributed inference, each worker is assigned a partition of the model.
|
distributed inference, each worker is assigned a partition of the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
local_rank: int,
|
local_rank: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
distributed_init_method: str,
|
distributed_init_method: str,
|
||||||
is_driver_worker: bool = False,
|
is_driver_worker: bool = False):
|
||||||
model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
|
|
||||||
) -> None:
|
|
||||||
# Register ops when worker init.
|
# Register ops when worker init.
|
||||||
from vllm_ascend import ops # noqa: F401
|
from vllm_ascend import ops # noqa: F401
|
||||||
|
|
||||||
@@ -101,7 +98,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
|
|
||||||
ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
|
ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
|
||||||
if model_config.runner_type == "pooling":
|
if model_config.runner_type == "pooling":
|
||||||
ModelRunnerClass = PoolingModelRunner
|
ModelRunnerClass = NPUPoolingModelRunner
|
||||||
elif self.model_config.is_encoder_decoder:
|
elif self.model_config.is_encoder_decoder:
|
||||||
ModelRunnerClass = EncoderDecoderModelRunner
|
ModelRunnerClass = EncoderDecoderModelRunner
|
||||||
self.model_runner: ModelRunnerBase = ModelRunnerClass(
|
self.model_runner: ModelRunnerBase = ModelRunnerClass(
|
||||||
@@ -110,8 +107,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
is_driver_worker=is_driver_worker,
|
is_driver_worker=is_driver_worker,
|
||||||
**speculative_args,
|
**speculative_args,
|
||||||
)
|
)
|
||||||
if model_runner_cls is not None:
|
|
||||||
self.model_runner = model_runner_cls(self.model_runner)
|
|
||||||
|
|
||||||
# Uninitialized cache engine. Will be initialized by
|
# Uninitialized cache engine. Will be initialized by
|
||||||
# initialize_cache.
|
# initialize_cache.
|
||||||
@@ -155,6 +150,26 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
else:
|
else:
|
||||||
self.profiler = None
|
self.profiler = None
|
||||||
|
|
||||||
|
def init_device(self) -> None:
|
||||||
|
if self.device_config.device.type == "npu":
|
||||||
|
self.device = torch.device(f"npu:{self.local_rank}")
|
||||||
|
NPUPlatform.set_device(self.device)
|
||||||
|
NPUPlatform.empty_cache()
|
||||||
|
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Not support device type: {self.device_config.device}")
|
||||||
|
# Initialize the distributed environment.
|
||||||
|
self._init_worker_distributed_environment(self.parallel_config,
|
||||||
|
self.rank,
|
||||||
|
self.distributed_init_method,
|
||||||
|
self.local_rank)
|
||||||
|
# Set random seed.
|
||||||
|
set_random_seed(self.model_config.seed)
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
self.model_runner.load_model()
|
||||||
|
|
||||||
def start_profile(self):
|
def start_profile(self):
|
||||||
if self.profiler is None:
|
if self.profiler is None:
|
||||||
raise RuntimeError("Profiler is not enabled.")
|
raise RuntimeError("Profiler is not enabled.")
|
||||||
@@ -165,28 +180,6 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
raise RuntimeError("Profiler is not enabled.")
|
raise RuntimeError("Profiler is not enabled.")
|
||||||
self.profiler.stop()
|
self.profiler.stop()
|
||||||
|
|
||||||
def init_device(self) -> None:
|
|
||||||
if self.device_config.device.type == "npu":
|
|
||||||
# # This env var set by Ray causes exceptions with graph building.
|
|
||||||
# os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
|
||||||
self.device = torch.device(f"npu:{self.local_rank}")
|
|
||||||
current_platform.set_device(self.device)
|
|
||||||
|
|
||||||
current_platform.empty_cache()
|
|
||||||
self.init_npu_memory = current_platform.mem_get_info()[0]
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Not support device type: {self.device_config.device}")
|
|
||||||
# Initialize the distributed environment.
|
|
||||||
init_worker_distributed_environment(self.parallel_config, self.rank,
|
|
||||||
self.distributed_init_method,
|
|
||||||
self.local_rank)
|
|
||||||
# Set random seed.
|
|
||||||
set_random_seed(self.model_config.seed)
|
|
||||||
|
|
||||||
def load_model(self):
|
|
||||||
self.model_runner.load_model()
|
|
||||||
|
|
||||||
def save_sharded_state(
|
def save_sharded_state(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
@@ -206,7 +199,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
self.model_runner.save_tensorized_model(
|
self.model_runner.save_tensorized_model(
|
||||||
tensorizer_config=tensorizer_config, )
|
tensorizer_config=tensorizer_config, )
|
||||||
|
|
||||||
@current_platform.inference_mode()
|
@NPUPlatform.inference_mode()
|
||||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
"""Profiles the peak memory usage of the model to determine how many
|
"""Profiles the peak memory usage of the model to determine how many
|
||||||
KV blocks may be allocated without OOMs.
|
KV blocks may be allocated without OOMs.
|
||||||
@@ -219,7 +212,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
"""
|
"""
|
||||||
# Profile the memory usage of the model and get the maximum number of
|
# Profile the memory usage of the model and get the maximum number of
|
||||||
# cache blocks that can be allocated with the remaining free memory.
|
# cache blocks that can be allocated with the remaining free memory.
|
||||||
current_platform.empty_cache()
|
NPUPlatform.empty_cache()
|
||||||
|
|
||||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||||
# of the model.
|
# of the model.
|
||||||
@@ -227,7 +220,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
|
|
||||||
# Calculate the number of blocks that can be allocated with the
|
# Calculate the number of blocks that can be allocated with the
|
||||||
# profiled peak memory.
|
# profiled peak memory.
|
||||||
free_npu_memory, total_npu_memory = current_platform.mem_get_info()
|
free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()
|
||||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||||
# GPU did not change their memory usage during the profiling.
|
# GPU did not change their memory usage during the profiling.
|
||||||
peak_memory = self.init_npu_memory - free_npu_memory
|
peak_memory = self.init_npu_memory - free_npu_memory
|
||||||
@@ -248,7 +241,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
# TODO: don`t need impl this func after empty_cache in
|
# TODO: don`t need impl this func after empty_cache in
|
||||||
# Worker.determine_num_available_blocks() unified`
|
# Worker.determine_num_available_blocks() unified`
|
||||||
current_platform.empty_cache()
|
NPUPlatform.empty_cache()
|
||||||
return num_npu_blocks, num_cpu_blocks
|
return num_npu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
def initialize_cache(self, num_gpu_blocks: int,
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
@@ -448,8 +441,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
self.model_config,
|
self.model_config,
|
||||||
self.parallel_config)
|
self.parallel_config)
|
||||||
|
|
||||||
|
def _init_worker_distributed_environment(
|
||||||
def init_worker_distributed_environment(
|
self,
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
rank: int,
|
rank: int,
|
||||||
distributed_init_method: Optional[str] = None,
|
distributed_init_method: Optional[str] = None,
|
||||||
@@ -457,11 +450,11 @@ def init_worker_distributed_environment(
|
|||||||
backend: str = "hccl") -> None:
|
backend: str = "hccl") -> None:
|
||||||
"""Initialize the distributed environment."""
|
"""Initialize the distributed environment."""
|
||||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||||
|
|
||||||
init_distributed_environment(parallel_config.world_size, rank,
|
init_distributed_environment(parallel_config.world_size, rank,
|
||||||
distributed_init_method, local_rank, backend)
|
distributed_init_method, local_rank,
|
||||||
|
backend)
|
||||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
ensure_model_parallel_initialized(
|
||||||
|
parallel_config.tensor_parallel_size,
|
||||||
parallel_config.pipeline_parallel_size)
|
parallel_config.pipeline_parallel_size)
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user