port deepseekv2 and mtp to main branch (#429)
### What this PR does / why we need it? This PR ports all the deepseek graph mode code and mtp code from v0.7.3 to the main branch --------- Signed-off-by: SidaoY <1024863041@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com> Signed-off-by: mengwei805 <mengwei25@huawei.com> Signed-off-by: libaokui <libaokui@huawei.com> Signed-off-by: q00832892 <qiaoyang19@huawei.com> Signed-off-by: ganyi <pleaplusone.gy@gmail.com> Co-authored-by: SidaoY <1024863041@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com> Co-authored-by: mengwei805 <mengwei25@huawei.com> Co-authored-by: libaokui <libaokui@huawei.com>
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import vllm_ascend.worker.cache_engine # noqa
|
||||
69
vllm_ascend/worker/cache_engine.py
Normal file
69
vllm_ascend/worker/cache_engine.py
Normal file
@@ -0,0 +1,69 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/model_runner.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
|
||||
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
|
||||
|
||||
|
||||
def allocate_kv_cache(
    self,
    num_blocks: int,
    device: str,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Allocate an MLA-split KV cache on the specified device.

    Replacement body for ``CacheEngine._allocate_kv_cache`` (installed by
    the module-level monkey-patch): instead of one cache tensor per layer,
    each attention layer gets two tensors — a "nope" cache whose last
    dimension is ``kv_lora_rank`` and a rotary-position ("pe") cache whose
    last dimension is ``qk_rope_head_dim``.

    Args:
        num_blocks: Number of KV-cache blocks to allocate.
        device: Target device string (e.g. ``"npu:0"`` or ``"cpu"``).

    Returns:
        One ``(nope_cache, pe_cache)`` pair of zero-initialized tensors per
        attention layer.
    """
    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
        num_blocks, self.block_size, self.num_kv_heads, self.head_size)
    # Pinned host memory is only meaningful (and only legal) for CPU
    # allocations, e.g. the swap space.
    pin_memory = is_pin_memory_available() if device == "cpu" else False

    # The backend shape's last dim (head_size) is replaced per tensor below.
    # Hoist the loop-invariant widths out of the per-layer loop.
    base_shape = kv_cache_shape[:-1]
    kv_lora_rank = self.model_config.hf_text_config.kv_lora_rank
    qk_rope_head_dim = self.model_config.hf_text_config.qk_rope_head_dim

    kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
    for _ in range(self.num_attention_layers):
        # Use zeros (not empty): the null block in CpuGpuBlockAllocator
        # requires at least that block to be zeroed-out; zeroing everything
        # keeps it simple.
        layer_kv_cache_nope = torch.zeros(base_shape + (kv_lora_rank, ),
                                          dtype=self.dtype,
                                          pin_memory=pin_memory,
                                          device=device)
        layer_kv_cache_pe = torch.zeros(base_shape + (qk_rope_head_dim, ),
                                        dtype=self.dtype,
                                        pin_memory=pin_memory,
                                        device=device)
        kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe))
    return kv_cache
|
||||
|
||||
|
||||
# Monkey-patch: replace vLLM's stock CacheEngine._allocate_kv_cache with the
# split (nope/pe) allocator above, but only when graph mode is explicitly
# enabled via the VLLM_ENABLE_GRAPH_MODE flag (string '1').
if VLLM_ENABLE_GRAPH_MODE == '1':
    CacheEngine._allocate_kv_cache = allocate_kv_cache
|
||||
@@ -18,18 +18,21 @@
|
||||
#
|
||||
|
||||
import dataclasses
|
||||
import itertools
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
|
||||
Type, TypeVar, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
@@ -53,7 +56,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists,
|
||||
is_pin_memory_available)
|
||||
is_pin_memory_available, supports_dynamo)
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
@@ -72,6 +75,7 @@ if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
TModelInputForNPU = TypeVar('TModelInputForNPU', bound="ModelInputForNPU")
|
||||
ENCODER_NUM = 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -526,6 +530,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
|
||||
seq_lens = []
|
||||
max_decode_seq_len = 0
|
||||
is_prompt = self.inter_data_list[0].is_prompt
|
||||
for inter_data in self.inter_data_list:
|
||||
seq_lens.extend(inter_data.seq_lens)
|
||||
if not inter_data.is_prompt:
|
||||
@@ -540,7 +545,26 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
for data in self.inter_data_list
|
||||
}
|
||||
|
||||
input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens),
|
||||
# Add graph_pad_size here
|
||||
if self.runner.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
graph_pad_size = self.runner.scheduler_config.max_num_seqs - len(
|
||||
seq_lens)
|
||||
else:
|
||||
graph_pad_size = -1
|
||||
|
||||
#print(f"before tensor input_tokens: {input_tokens}")
|
||||
#print(f"before tensor input_positions: {input_positions}")
|
||||
#print(f"before list seq_lens: {seq_lens}")
|
||||
input_tokens = flatten_2d_lists(input_tokens)
|
||||
input_positions = flatten_2d_lists(input_positions)
|
||||
if graph_pad_size != -1 and not is_prompt:
|
||||
input_tokens.extend(itertools.repeat(0, graph_pad_size))
|
||||
input_positions.extend( # type: ignore
|
||||
itertools.repeat(0, graph_pad_size))
|
||||
seq_lens.extend(itertools.repeat(1, graph_pad_size))
|
||||
query_lens.extend(itertools.repeat(1, graph_pad_size))
|
||||
input_tokens_tensor = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
if mrope_input_positions is not None:
|
||||
@@ -548,13 +572,16 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
else:
|
||||
input_positions_tensor = torch.tensor(
|
||||
flatten_2d_lists(input_positions),
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
input_positions_tensor = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
#print(f"after tensor input_tokens_tensor: {input_tokens_tensor}")
|
||||
#print(f"after tensor input_positions_tensor: {input_positions_tensor}")
|
||||
#print(f"after list seq_lens: {seq_lens}")
|
||||
|
||||
# Attention metadata.
|
||||
attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens)
|
||||
attn_metadata = self.attn_metadata_builder.build(
|
||||
seq_lens, query_lens, graph_pad_size)
|
||||
|
||||
# LoRA data.
|
||||
lora_requests = set()
|
||||
@@ -582,6 +609,13 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
]
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
if self.runner.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
torch._dynamo.mark_static(input_tokens_tensor)
|
||||
torch._dynamo.mark_static(input_positions_tensor)
|
||||
torch._dynamo.mark_static(attn_metadata.block_tables)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens_tensor,
|
||||
input_positions=input_positions_tensor,
|
||||
@@ -841,6 +875,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
|
||||
self.in_profile_run = False
|
||||
|
||||
self.graph_block_tables = np.zeros(
|
||||
(self.vllm_config.scheduler_config.max_num_seqs,
|
||||
(model_config.max_model_len + self.block_size - 1) //
|
||||
self.block_size),
|
||||
dtype=np.int32)
|
||||
|
||||
# Attention-free but stateful models like Mamba need a placeholder attn
|
||||
# backend, as the attention metadata is needed to manage internal state.
|
||||
# However we must bypass attention selection altogether for some models
|
||||
@@ -930,6 +970,26 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
)
|
||||
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||
|
||||
# adapter torch compile with npu_backend
|
||||
if self.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
# 通信算子成图
|
||||
patch_for_hcom()
|
||||
# 设置npu的config,如果不设置config,可以使用默认的,那可以设置npu_backend="npu"
|
||||
config = torchair.CompilerConfig()
|
||||
config.experimental_config.frozen_parameter = True
|
||||
config.experimental_config.tiling_schedule_optimize = True
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
self.compile_model = torchair.inference.cache_compile(
|
||||
self.model.forward,
|
||||
dynamic=True,
|
||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
@@ -1219,10 +1279,43 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
self.attn_state.begin_forward(model_input)
|
||||
|
||||
assert model_input.attn_metadata is not None
|
||||
if self.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
torch._dynamo.mark_static(model_input.input_tokens)
|
||||
torch._dynamo.mark_static(model_input.input_positions)
|
||||
torch._dynamo.mark_static(model_input.attn_metadata.block_tables)
|
||||
torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping)
|
||||
torch._dynamo.mark_static(
|
||||
model_input.attn_metadata.query_start_loc)
|
||||
torch._dynamo.mark_static(model_input.attn_metadata.seq_start_loc)
|
||||
for kv in kv_caches:
|
||||
if isinstance(kv, tuple):
|
||||
torch._dynamo.mark_static(kv[0])
|
||||
torch._dynamo.mark_static(kv[1])
|
||||
|
||||
# TODO(andoorve): We can remove this once all
|
||||
# virtual engines share the same kv cache.
|
||||
virtual_engine = model_input.virtual_engine
|
||||
model_executable = self.model
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
previous_hidden_states = kwargs.get("previous_hidden_states")
|
||||
if prefill_meta is None and self.vllm_config.compilation_config.level > 0:
|
||||
model_executable = self.compile_model
|
||||
# Note: graph_batch_size value not same as GPU
|
||||
graph_batch_size = model_input.input_tokens.shape[ # type: ignore
|
||||
0] # type: ignore
|
||||
# Note: previous_hidden_states maybe None not same as GPU
|
||||
if previous_hidden_states is not None:
|
||||
previous_hidden_states = torch.cat([
|
||||
previous_hidden_states,
|
||||
torch.empty([
|
||||
graph_batch_size - previous_hidden_states.shape[0],
|
||||
*previous_hidden_states.shape[1:]
|
||||
],
|
||||
dtype=previous_hidden_states.dtype,
|
||||
device=previous_hidden_states.device)
|
||||
])
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# Receive KV cache in distributed KV cache transfer setting
|
||||
# In disagg prefill setting, it will also recv hidden states and bypass
|
||||
@@ -1248,8 +1341,11 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_inner_state else {}
|
||||
|
||||
previous_hidden_states = kwargs.get("previous_hidden_states")
|
||||
model_kwargs = {}
|
||||
if self.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
model_kwargs = {"inputs_embeds": None}
|
||||
else:
|
||||
model_kwargs = {}
|
||||
if previous_hidden_states is not None:
|
||||
model_kwargs["previous_hidden_states"] = previous_hidden_states
|
||||
|
||||
@@ -1273,44 +1369,30 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
**seqlen_agnostic_kwargs,
|
||||
**model_kwargs)
|
||||
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.record()
|
||||
|
||||
# Sending KV cache in distributed KV cache transfer setting
|
||||
# NOTE: the send operation is non-blocking
|
||||
if self.need_send_kv(model_input, kv_caches):
|
||||
get_kv_transfer_group().send_kv_caches_and_hidden_states(
|
||||
# model_executable is used to know which layer the current
|
||||
# worker is working on, so that we can send KV for only those
|
||||
# layers.
|
||||
model_executable,
|
||||
model_input,
|
||||
kv_caches,
|
||||
hidden_or_intermediate_states,
|
||||
)
|
||||
|
||||
# Compute the logits in the last pipeline stage.
|
||||
if not get_pp_group().is_last_rank:
|
||||
if (self.is_driver_worker
|
||||
and hidden_or_intermediate_states is not None
|
||||
and isinstance(hidden_or_intermediate_states,
|
||||
IntermediateTensors)
|
||||
and self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
hidden_or_intermediate_states.tensors["model_forward_time"] = (
|
||||
torch.tensor(model_forward_time + orig_model_forward_time))
|
||||
return hidden_or_intermediate_states
|
||||
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
# Compute the logits in the last pipeline stage.
|
||||
if not get_pp_group().is_last_rank:
|
||||
if (self.is_driver_worker
|
||||
and hidden_or_intermediate_states is not None
|
||||
and isinstance(hidden_or_intermediate_states,
|
||||
IntermediateTensors)
|
||||
and self.observability_config is not None and
|
||||
self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
hidden_or_intermediate_states.tensors[
|
||||
"model_forward_time"] = (
|
||||
torch.tensor(model_forward_time +
|
||||
orig_model_forward_time))
|
||||
return hidden_or_intermediate_states
|
||||
# TODO: remove the synchronize here
|
||||
torch.npu.synchronize()
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
@@ -1348,6 +1430,9 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
hidden_states = hidden_or_intermediate_states.index_select(
|
||||
0, indices)
|
||||
output.prefill_hidden_states = hidden_or_intermediate_states
|
||||
elif self.vllm_config.compilation_config.level == \
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
hidden_states = hidden_or_intermediate_states[:len(indices)]
|
||||
else:
|
||||
hidden_states = hidden_or_intermediate_states
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -47,8 +48,7 @@ from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.attention.attention import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -104,6 +104,27 @@ class NPUModelRunner:
|
||||
raise NotImplementedError(
|
||||
"Non-Attention backend is not supported by V1 NPUModelRunner.")
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.head_size,
|
||||
self.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
use_mla=self.model_config.use_mla,
|
||||
)
|
||||
if self.attn_backend is None:
|
||||
error_msg = (
|
||||
f"Error with get_att_backend: {self.head_size=}, "
|
||||
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
|
||||
f"{self.model_config.is_attention_free=}, "
|
||||
f"{self.model_config.use_mla=}")
|
||||
logger.error(error_msg)
|
||||
raise NotImplementedError(
|
||||
"Non-Attention backend is not supported by V1 GPUModelRunner.")
|
||||
|
||||
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
||||
weakref.proxy(self))
|
||||
|
||||
# Multi-modal data support
|
||||
self.input_registry = INPUT_REGISTRY
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
@@ -191,6 +212,12 @@ class NPUModelRunner:
|
||||
pin_memory=True)
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
|
||||
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
||||
|
||||
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
@@ -200,6 +227,8 @@ class NPUModelRunner:
|
||||
self.input_positions_cpu = torch.arange(0,
|
||||
self.max_num_tokens,
|
||||
device="cpu")
|
||||
self.attn_mask = None
|
||||
self.attn_state = None
|
||||
|
||||
# NOTE: Pre-construct a mask matrix to improve the efficiency of
|
||||
# attention mask construction during inference.
|
||||
@@ -396,7 +425,11 @@ class NPUModelRunner:
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
# Copy the blocks from CPU to NPU.
|
||||
modified_batch = self.attn_metadata_builder.reorder_batch(
|
||||
self.input_batch, scheduler_output)
|
||||
if modified_batch:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.input_batch.block_table.commit(num_reqs)
|
||||
@@ -430,14 +463,13 @@ class NPUModelRunner:
|
||||
self.positions[:total_num_scheduled_tokens].copy_(
|
||||
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||
positions = self.positions[:total_num_scheduled_tokens]
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
|
||||
self.seq_lens_np[:num_reqs] = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
seq_lens = self.seq_lens_cpu[:num_reqs]
|
||||
|
||||
query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions_np // self.block_size)
|
||||
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
|
||||
@@ -446,8 +478,6 @@ class NPUModelRunner:
|
||||
np.add(block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:total_num_scheduled_tokens])
|
||||
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
||||
self.device, non_blocking=True)
|
||||
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
@@ -461,15 +491,14 @@ class NPUModelRunner:
|
||||
query_lens=num_scheduled_tokens,
|
||||
position=positions,
|
||||
attn_state=attn_state)
|
||||
self.attn_mask = attn_mask
|
||||
self.attn_state = attn_state # type: ignore
|
||||
|
||||
attn_metadata = AscendMetadata(
|
||||
seq_lens=query_lens,
|
||||
context_lens=seq_lens,
|
||||
slot_mapping=slot_mapping,
|
||||
block_tables=(
|
||||
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_prefix_len=None,
|
||||
)
|
||||
|
||||
# Prepare input_ids
|
||||
@@ -804,6 +833,9 @@ class NPUModelRunner:
|
||||
# different GPUs, and `kv_cache_config.num_blocks` is set to
|
||||
# the min of all `num_blocks`. Verify it here.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
||||
# encounter OOM issue
|
||||
num_blocks = num_blocks // 4
|
||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
|
||||
@@ -44,6 +44,7 @@ from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import try_register_lib, vllm_version_is
|
||||
from vllm_ascend.worker.model_runner import NPUModelRunner
|
||||
@@ -313,8 +314,14 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size):
|
||||
num_layers = len(self.cache_engine[ve].gpu_cache)
|
||||
for i in range(num_layers):
|
||||
torch_npu.npu_format_cast(self.cache_engine[ve].gpu_cache[i],
|
||||
2)
|
||||
if torch.is_tensor(self.cache_engine[ve].gpu_cache[i]):
|
||||
torch_npu.npu_format_cast(
|
||||
self.cache_engine[ve].gpu_cache[i], 2)
|
||||
else:
|
||||
torch_npu.npu_format_cast(
|
||||
self.cache_engine[ve].gpu_cache[i][0], 2)
|
||||
torch_npu.npu_format_cast(
|
||||
self.cache_engine[ve].gpu_cache[i][1], 2)
|
||||
self.gpu_cache = [
|
||||
self.cache_engine[ve].gpu_cache
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size)
|
||||
@@ -495,6 +502,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
backend: str = "hccl") -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
parallel_config = self.parallel_config
|
||||
additional_config = self.vllm_config.additional_config
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank,
|
||||
@@ -502,6 +510,14 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
expert_tensor_parallel_size = 1
|
||||
if additional_config is not None and hasattr(
|
||||
additional_config, "expert_tensor_parallel_size"):
|
||||
expert_tensor_parallel_size = getattr(
|
||||
additional_config, "expert_tensor_parallel_size")
|
||||
init_ascend_model_parallel(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
expert_tensor_parallel_size)
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import try_register_lib, vllm_version_is
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
@@ -209,6 +210,8 @@ class NPUWorker(WorkerBase):
|
||||
|
||||
def _init_worker_distributed_environment(self) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
additional_config = self.vllm_config.additional_config
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
set_custom_all_reduce(
|
||||
not self.parallel_config.disable_custom_all_reduce)
|
||||
init_distributed_environment(self.parallel_config.world_size,
|
||||
@@ -217,6 +220,13 @@ class NPUWorker(WorkerBase):
|
||||
ensure_model_parallel_initialized(
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
expert_tensor_parallel_size = 1
|
||||
if additional_config is not None and "expert_tensor_parallel_size" in additional_config:
|
||||
expert_tensor_parallel_size = int(
|
||||
additional_config["expert_tensor_parallel_size"])
|
||||
init_ascend_model_parallel(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
expert_tensor_parallel_size)
|
||||
ensure_kv_transfer_initialized(self.vllm_config)
|
||||
|
||||
def _init_profiler(self):
|
||||
|
||||
Reference in New Issue
Block a user