port deepseekv2 and mtp to main branch (#429)
### What this PR does / why we need it? This PR ports all the deepseek graph mode code and mtp code from v0.7.3 to the main branch --------- Signed-off-by: SidaoY <1024863041@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com> Signed-off-by: mengwei805 <mengwei25@huawei.com> Signed-off-by: libaokui <libaokui@huawei.com> Signed-off-by: q00832892 <qiaoyang19@huawei.com> Signed-off-by: ganyi <pleaplusone.gy@gmail.com> Co-authored-by: SidaoY <1024863041@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com> Co-authored-by: mengwei805 <mengwei25@huawei.com> Co-authored-by: libaokui <libaokui@huawei.com>
This commit is contained in:
128
examples/disaggregated_prefill_hccl.py
Normal file
128
examples/disaggregated_prefill_hccl.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
This file demonstrates the example usage of disaggregated prefilling
|
||||
We will launch 2 vllm instances (NPU 0,1 for prefill and NPU 2,3 for decode),
|
||||
and then transfer the KV cache between them.
|
||||
"""
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import time
|
||||
from multiprocessing import Event, Process
|
||||
|
||||
|
||||
def clean_up():
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import (
|
||||
destroy_distributed_environment, destroy_model_parallel)
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
def run_prefill(prefill_done, process_close):
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
prompts = [
|
||||
"Hello, how are you today?", "Hi, what is your name?",
|
||||
"Tell me a very long story.", "what is your favourite book?"
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
||||
|
||||
ktc = KVTransferConfig.from_cli(
|
||||
'{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
|
||||
)
|
||||
|
||||
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
|
||||
# memory. You may need to adjust the value to fit your GPU.
|
||||
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8,
|
||||
tensor_parallel_size=2)
|
||||
|
||||
llm.generate(prompts, sampling_params)
|
||||
print("Prefill node is finished.")
|
||||
prefill_done.set()
|
||||
|
||||
# To keep the prefill node running in case the decode node is not done;
|
||||
# otherwise, the script might exit prematurely, causing incomplete decoding.
|
||||
try:
|
||||
while not process_close.is_set():
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("Script stopped by user.")
|
||||
finally:
|
||||
print("Cleanup prefill resources")
|
||||
del llm
|
||||
clean_up()
|
||||
|
||||
|
||||
def run_decode(prefill_done):
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
prompts = [
|
||||
"Hello, how are you today?", "Hi, what is your name?",
|
||||
"Tell me a very long story.", "what is your favourite book?"
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95)
|
||||
|
||||
ktc = KVTransferConfig.from_cli(
|
||||
'{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
|
||||
)
|
||||
|
||||
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8,
|
||||
tensor_parallel_size=2)
|
||||
|
||||
# Wait for the producer to start the comsumer
|
||||
print("Waiting for prefill node to finish...")
|
||||
prefill_done.wait()
|
||||
|
||||
# At this point when the prefill_done is set, the kv-cache should have been
|
||||
# transferred to this decode node, so we can start decoding.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
del llm
|
||||
clean_up()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mp.get_context('spawn')
|
||||
|
||||
prefill_done = Event()
|
||||
process_close = Event()
|
||||
prefill_process = Process(target=run_prefill,
|
||||
args=(
|
||||
prefill_done,
|
||||
process_close,
|
||||
))
|
||||
decode_process = Process(target=run_decode, args=(prefill_done, ))
|
||||
|
||||
# Start prefill node
|
||||
prefill_process.start()
|
||||
|
||||
# Start decode node
|
||||
decode_process.start()
|
||||
|
||||
# Terminate the prefill node when decode is finished
|
||||
decode_process.join()
|
||||
|
||||
# Terminate prefill process
|
||||
process_close.set()
|
||||
prefill_process.join()
|
||||
prefill_process.terminate()
|
||||
print("All process done!")
|
||||
86
examples/dp_offline/data_parallel.py
Normal file
86
examples/dp_offline/data_parallel.py
Normal file
@@ -0,0 +1,86 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# usage:
|
||||
# python examples/offline_inference_data_parallel.py
|
||||
# we need to have a launcher to create multiple data parallel
|
||||
# ranks. And each rank will create a vLLM instance to process its own prompts.
|
||||
|
||||
import gc
|
||||
import os
|
||||
|
||||
VLLM_ENABLE_GRAPGH_MODE = os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1"
|
||||
|
||||
|
||||
def main():
|
||||
dp_rank = int(os.environ['RANK'])
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
dp_size = int(os.environ['WORLD_SIZE'])
|
||||
master_addr = os.environ['MASTER_ADDR']
|
||||
master_port = os.environ['MASTER_PORT']
|
||||
tp_size = 4
|
||||
etp_size = 2
|
||||
|
||||
os.environ["VLLM_DP_RANK"] = str(dp_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(dp_size)
|
||||
os.environ["VLLM_DP_MASTER_IP"] = master_addr
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = master_port
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = ",".join(
|
||||
str(i)
|
||||
for i in range(local_rank * tp_size, (local_rank + 1) * tp_size))
|
||||
|
||||
import torch
|
||||
import torch_npu # noqa
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed.parallel_state import (
|
||||
destroy_distributed_environment, destroy_model_parallel)
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
] * 4
|
||||
|
||||
promts_per_rank = len(prompts) // dp_size
|
||||
start = dp_rank * promts_per_rank
|
||||
end = start + promts_per_rank
|
||||
prompts = prompts[start:end]
|
||||
if len(prompts) == 0:
|
||||
prompts = ["Placeholder"]
|
||||
print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")
|
||||
num_seqs = len(prompts)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=4,
|
||||
min_tokens=4)
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model="deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
tensor_parallel_size=tp_size,
|
||||
trust_remote_code=True,
|
||||
expert_tensor_parallel_size=etp_size,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=num_seqs,
|
||||
compilation_config=1 if VLLM_ENABLE_GRAPGH_MODE else 0,
|
||||
)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"DP rank {dp_rank}, Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
|
||||
del llm
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
21
examples/dp_offline/run_dp.sh
Normal file
21
examples/dp_offline/run_dp.sh
Normal file
@@ -0,0 +1,21 @@
|
||||
export HCCL_IF_IP=${local_ip}
|
||||
export GLOO_SOCKET_IFNAME=${ifname}
|
||||
export TP_SOCKET_IFNAME=${ifname}
|
||||
export HCCL_SOCKET_IFNAME=${ifname}
|
||||
|
||||
# dp_size = node_size * dp_per_node
|
||||
node_size=1
|
||||
node_rank=0
|
||||
dp_per_node=2
|
||||
master_addr=127.0.0.1
|
||||
master_port=12345
|
||||
|
||||
rm -rf ./.torchair_cache/
|
||||
rm -rf ./dynamo_*
|
||||
rm -rf /root/ascend/log/debug/plog/*
|
||||
export VLLM_ENABLE_GRAPH_MODE=0
|
||||
export VLLM_ENABLE_MC2=0
|
||||
|
||||
torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \
|
||||
--node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \
|
||||
data_parallel.py
|
||||
49
examples/offline_inference_npu_v1.py
Normal file
49
examples/offline_inference_npu_v1.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
os.environ["VLLM_USE_V1"] = "1"
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
if __name__ == "__main__":
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
|
||||
# Create an LLM.
|
||||
llm = LLM(model="/data/weights/deepseek-ai/deepseekv3-lite-base-latest",
|
||||
tensor_parallel_size=2,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024)
|
||||
|
||||
# Generate texts from the prompts.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
@@ -26,8 +26,6 @@ import torch
|
||||
from PIL import Image
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import TaskOption
|
||||
from vllm.distributed.parallel_state import (destroy_distributed_environment,
|
||||
destroy_model_parallel)
|
||||
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
@@ -35,6 +33,15 @@ from vllm.utils import is_list_of
|
||||
|
||||
from tests.model_utils import (TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs)
|
||||
# TODO: remove this part after the patch merged into vllm, if
|
||||
# we not explicitly patch here, some of them might be effectiveless
|
||||
# in pytest scenario
|
||||
from vllm_ascend.utils import adapt_patch # noqa E402
|
||||
|
||||
adapt_patch(True)
|
||||
|
||||
from vllm.distributed.parallel_state import ( # noqa E402
|
||||
destroy_distributed_environment, destroy_model_parallel)
|
||||
|
||||
_M = TypeVar("_M")
|
||||
|
||||
|
||||
@@ -31,13 +31,14 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
||||
CommonMetadataBuilder,
|
||||
compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
|
||||
from vllm_ascend.worker.model_runner import (
|
||||
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||
|
||||
@@ -222,7 +223,7 @@ class AscendMLAAttentionBackend(AscendAttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (1, num_blocks, block_size, num_kv_heads * head_size)
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -552,10 +553,33 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
inter_data.block_tables,
|
||||
)
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int,
|
||||
block_tables: List[List[int]]) -> torch.Tensor:
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
|
||||
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs
|
||||
|
||||
graph_block_tables = self.runner.graph_block_tables # [:num_seqs]
|
||||
for i, block_table in enumerate(block_tables):
|
||||
if block_table:
|
||||
num_blocks = len(block_table)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[i, :num_blocks] = block_table
|
||||
else:
|
||||
graph_block_tables[
|
||||
i, :max_blocks] = block_table[:max_blocks]
|
||||
|
||||
return torch.from_numpy(graph_block_tables).to(
|
||||
device=self.runner.device, non_blocking=True)
|
||||
|
||||
def build(
|
||||
self,
|
||||
seq_lens: List[int],
|
||||
query_lens: List[int],
|
||||
graph_pad_size: int,
|
||||
):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
@@ -568,6 +592,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
device = self.runner.device
|
||||
use_torchair_graph = graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
@@ -582,12 +607,36 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.attn_mask = None
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
if self.num_prefills == 0 and use_torchair_graph:
|
||||
num_seqs = len(seq_lens)
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
|
||||
self.block_tables.extend([[]] * graph_pad_size)
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if self.num_prefills > 0:
|
||||
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
||||
max_prefill_seq_len,
|
||||
self.input_builder.runner.model_config.dtype,
|
||||
self.input_builder.runner.device)
|
||||
else:
|
||||
self.attn_mask = None
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
||||
|
||||
assert device is not None
|
||||
@@ -855,14 +904,100 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
self.q_proj = extra_impl_args['q_proj']
|
||||
self.kv_b_proj = extra_impl_args['kv_b_proj']
|
||||
self.o_proj = extra_impl_args['o_proj']
|
||||
self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa',
|
||||
None)
|
||||
self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None)
|
||||
self.k_pe_cache = None
|
||||
self.k_nope_cache = None
|
||||
self.w_kc = None
|
||||
self.w_vc = None
|
||||
|
||||
def exec_kv(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
kv_cache: Tuple,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
B = hidden_states.shape[0]
|
||||
N = self.num_kv_heads
|
||||
S = 1
|
||||
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
|
||||
k_pe, k_nope = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
||||
kv,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
slots.to(torch.int64),
|
||||
kv_cache[1],
|
||||
kv_cache[0],
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode="PA",
|
||||
)
|
||||
|
||||
return k_pe, k_nope
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
def rope_single(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
B, N, D = x.shape
|
||||
S = 1
|
||||
x = x.view(B, N, S, D)
|
||||
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
|
||||
return x.view(B, N, D)
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
if self.w_kc is None or self.w_vc is None:
|
||||
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
||||
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
|
||||
self.kv_lora_rank)
|
||||
self.w_kc = kv_b_proj_weight[:, :self.
|
||||
qk_nope_head_dim, :].contiguous()
|
||||
self.w_vc = kv_b_proj_weight[:,
|
||||
self.qk_nope_head_dim:, :].transpose(
|
||||
1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
hidden_states_or_kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
@@ -873,7 +1008,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
Args:
|
||||
hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size]
|
||||
num_tokens = batch_size * seq_len
|
||||
kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
|
||||
hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
|
||||
k_pe: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache: shape = [1, num_blocks, block_size,
|
||||
num_kv_heads * head_size]
|
||||
@@ -889,71 +1024,86 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# for profile run
|
||||
return hidden_states_or_q_c
|
||||
|
||||
num_tokens = hidden_states_or_q_c.shape[0]
|
||||
q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
|
||||
self.qk_head_dim)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
if k_pe is None and attn_metadata.decode_metadata:
|
||||
seq_len = self.rotary_emb.max_position_embeddings
|
||||
|
||||
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
|
||||
cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
|
||||
sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
|
||||
cos = cos[attn_metadata.input_positions]
|
||||
sin = sin[attn_metadata.input_positions]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
|
||||
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
||||
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
|
||||
q_pe = q_pe.reshape(num_tokens, -1)
|
||||
k_pe = k_pe.reshape(num_tokens, -1)
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
|
||||
k_pe)
|
||||
q_pe = q_pe.view(ori_q_pe_shape)
|
||||
k_pe = k_pe.view(ori_k_pe_shape)
|
||||
q_pe = self.rope_single(q_pe, cos, sin)
|
||||
k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
|
||||
kv_cache, attn_metadata.slot_mapping)
|
||||
else:
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
|
||||
k_pe)
|
||||
|
||||
if self.w_kc is None or self.w_vc is None:
|
||||
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
||||
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
|
||||
self.kv_lora_rank)
|
||||
self.w_kc = kv_b_proj_weight[:, :self.
|
||||
qk_nope_head_dim, :].contiguous()
|
||||
self.w_vc = kv_b_proj_weight[:,
|
||||
self.qk_nope_head_dim:, :].transpose(
|
||||
1, 2).contiguous()
|
||||
if k_pe is None:
|
||||
# NOTE: k_pe is None when graph mode enabled
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||
hidden_states_or_kv_c_normed)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
else:
|
||||
kv_c_normed = hidden_states_or_kv_c_normed
|
||||
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
|
||||
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
||||
# NOTE: When scaling not specified
|
||||
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
|
||||
q_pe = q_pe.reshape(num_tokens, -1)
|
||||
k_pe = k_pe.reshape(num_tokens, -1)
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
||||
q_pe, k_pe)
|
||||
q_pe = q_pe.view(ori_q_pe_shape)
|
||||
k_pe = k_pe.view(ori_k_pe_shape)
|
||||
else:
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
||||
q_pe, k_pe)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
kv_heads_num = self.num_heads
|
||||
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens, kv_heads_num,
|
||||
-1)
|
||||
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
|
||||
self.num_heads, -1)
|
||||
k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
|
||||
dim=-1)
|
||||
k_cache = torch.cat(
|
||||
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
|
||||
dim=2)
|
||||
k_pe = k_pe.expand(-1, self.num_heads, -1)
|
||||
key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe],
|
||||
dim=2)
|
||||
else:
|
||||
kv_heads_num = self.num_kv_heads
|
||||
q_nope_t = torch.transpose(q_nope, 0, 1)
|
||||
q_nope_out = torch.bmm(q_nope_t, self.w_kc)
|
||||
q_nope = torch.transpose(q_nope_out, 0, 1)
|
||||
k_cache = torch.cat(
|
||||
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
|
||||
dim=2)
|
||||
|
||||
query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
|
||||
self.num_heads, -1)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
num_blocks, block_size, _ = key_cache.shape
|
||||
|
||||
key_cache = key_cache.view(
|
||||
num_blocks, block_size, self.num_kv_heads,
|
||||
self.qk_rope_head_dim + self.kv_lora_rank)
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache_siso(key=k_cache,
|
||||
key_cache=key_cache,
|
||||
slot_indices=slots)
|
||||
# TODO: Replace the env with more flexible expressions
|
||||
if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
if len(kv_cache) > 0 and kv_cache[0].numel(
|
||||
) > 0 and attn_metadata.num_prefills > 0:
|
||||
slots = attn_metadata.slot_mapping
|
||||
# NOTE: Seperate the kv cache in advance to avoid OOM or other issues
|
||||
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
||||
num_tokens, self.num_kv_heads, -1),
|
||||
value=k_pe,
|
||||
key_cache=kv_cache[0],
|
||||
value_cache=kv_cache[1],
|
||||
slot_indices=slots)
|
||||
else:
|
||||
if kv_cache.numel() > 0:
|
||||
key = torch.cat([
|
||||
kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe
|
||||
],
|
||||
dim=2)
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache_siso(key=key,
|
||||
key_cache=kv_cache,
|
||||
slot_indices=slots)
|
||||
|
||||
if attn_metadata.num_prefills > 0:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
@@ -964,12 +1114,15 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
if (attn_metadata.block_tables is None
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
assert attn_metadata.prefill_metadata is not None
|
||||
assert attn_metadata.prefill_metadata.seq_lens is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.prefill_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
k_pe = k_pe.repeat(1, self.num_heads, 1)
|
||||
key = torch.cat(
|
||||
[k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
@@ -987,29 +1140,55 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
||||
)
|
||||
elif attn_metadata.decode_metadata:
|
||||
assert kv_cache is not None
|
||||
# if torch.empty is used here, the preemptive scheduling case of
|
||||
# test_mtp_correctness.py will fail to run.
|
||||
attn_output = torch.randn(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
block_tables = attn_metadata.decode_metadata.block_tables
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
attn_output_t = torch.transpose(attn_output, 0, 1)
|
||||
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
|
||||
attn_output = torch.transpose(attn_output_t, 0, 1)
|
||||
if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=1,
|
||||
input_layout="BNSD",
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=attn_metadata.block_tables,
|
||||
block_size=kv_cache[0].shape[1],
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
||||
)
|
||||
attn_output = attn_output.view(num_tokens, -1,
|
||||
self.kv_lora_rank).transpose(
|
||||
0, 1)
|
||||
attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1)
|
||||
else:
|
||||
# if torch.empty is used here, the preemptive scheduling case of
|
||||
# test_mtp_correctness.py will fail to run.
|
||||
attn_output = torch.randn(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||
np.int32))
|
||||
block_tables = attn_metadata.decode_metadata.block_tables
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=query,
|
||||
key_cache=kv_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=self.seq_lens_tensor_cpu,
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
attn_output_t = torch.transpose(attn_output, 0, 1)
|
||||
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
|
||||
attn_output = torch.transpose(attn_output_t, 0, 1)
|
||||
|
||||
output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
|
||||
|
||||
|
||||
@@ -24,6 +24,10 @@ import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
|
||||
|
||||
class AscendAttentionBackend(AttentionBackend):
|
||||
@@ -44,6 +48,10 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
|
||||
return AscendAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
@@ -94,11 +102,11 @@ class AscendAttentionState(Enum):
|
||||
class AscendMetadata:
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
block_tables: Optional[torch.Tensor]
|
||||
block_tables: torch.Tensor
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
context_lens: Optional[List[int]] = None
|
||||
query_lens: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
@@ -117,6 +125,36 @@ class AscendMetadata:
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder:
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(self, num_reqs, num_actual_tokens, max_query_len,
|
||||
common_prefix_len):
|
||||
block_table = (
|
||||
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
||||
query_lens = self.runner.query_lens
|
||||
seq_lens = self.runner.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True)
|
||||
attn_mask = self.runner.attn_mask
|
||||
attn_state = self.runner.attn_state
|
||||
|
||||
attn_metadata = AscendMetadata(block_tables=block_table,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
@@ -229,29 +267,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
out=output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
block_tables = attn_metadata.block_tables
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=attn_metadata.context_lens,
|
||||
out=output)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
# use paged attention
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
context_lens=attn_metadata.context_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
# use chunked prefill for head size 192 scenario, like deepseek
|
||||
# paged_attention_splitfuse maybe crash at such scenario
|
||||
# TODO: vanilla path will be removed after the kernel support
|
||||
# head_size 192 scenario
|
||||
if self.head_size == 192:
|
||||
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
|
||||
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
|
||||
cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
|
||||
cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
|
||||
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
|
||||
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
|
||||
max_seqlen_q = torch.max(attn_metadata.query_lens)
|
||||
max_seqlen_k = torch.max(attn_metadata.seq_lens)
|
||||
vanilla_chunked_prefill(output, query, self.key_cache,
|
||||
self.value_cache,
|
||||
attn_metadata.block_tables,
|
||||
cu_seqlen_q, cu_seqlen_k, max_seqlen_q,
|
||||
max_seqlen_k, self.scale, None, True)
|
||||
else:
|
||||
torch_npu._npu_paged_attention_splitfuse(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
mask=attn_metadata.attn_mask,
|
||||
block_table=attn_metadata.block_tables,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
561
vllm_ascend/attention/mla_v1.py
Normal file
561
vllm_ascend/attention/mla_v1.py
Normal file
@@ -0,0 +1,561 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||
from vllm_ascend.ops.cache import concat_and_cache_mla
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AscendMLABackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "VLLM_ASCEND_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return AscendMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls():
|
||||
return AscendMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["MLAAttentionImpl"]:
|
||||
return AscendMLAImpl
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMLAPrefillMetadata:
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
attn_mask: torch.Tensor
|
||||
query_lens: list[int]
|
||||
context_lens: torch.Tensor
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
max_query_len: int
|
||||
max_context_len: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMLADecodeMetadata:
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMLAMetadata:
|
||||
"""Metadata for MLACommon.
|
||||
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
attn_mask: torch.Tensor = None
|
||||
# chunked prefill by default if no attn_states passed
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
decode: Optional[AscendMLADecodeMetadata] = None
|
||||
prefill: Optional[AscendMLAPrefillMetadata] = None
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
# supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
|
||||
# if self.head_dim is not None and self.head_dim \
|
||||
# not in supported_head_sizes:
|
||||
# raise ValueError(
|
||||
# f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
# f"received {self.head_dim}.")
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
|
||||
|
||||
class AscendMLAMetadataBuilder:
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
# _attn_mask_builder = None
|
||||
def __init__(self,
|
||||
runner: "NPUModelRunner",
|
||||
metadata_cls: Optional[AscendMLAMetadata] = None):
|
||||
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
|
||||
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
||||
self.runner = runner
|
||||
scheduler_config = runner.scheduler_config
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
# self.attn_mask = None
|
||||
# if AscendMLAMetadataBuilder._attn_mask_builder is None:
|
||||
# AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||
# 128, self.runner.model_config.dtype
|
||||
# )
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are at
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
# where attention is likely memory-bound and "prefill" to mean requests
|
||||
# where attention is likely compute-bound, TODO(lucas): figure out a
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the TritonMLA._forward_decode only supports
|
||||
# num_tokens = 1
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
first_prefill = 0
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
if decodes[num_decodes - i] >= num_decodes:
|
||||
input_batch.swap_states(prefills[first_prefill],
|
||||
decodes[num_decodes - i])
|
||||
first_prefill += 1
|
||||
modified_batch = True
|
||||
else:
|
||||
break
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
|
||||
def build(self,
|
||||
num_reqs: int,
|
||||
num_actual_tokens: int,
|
||||
max_query_len: int,
|
||||
common_prefix_len: Optional[int] = None) -> AscendMLAMetadata:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||
# it blocks on all previous kernels.
|
||||
device = self.runner.device
|
||||
block_table = (
|
||||
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True).long()
|
||||
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True).long()
|
||||
|
||||
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
|
||||
query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
|
||||
num_reqs]
|
||||
seq_lens = seq_lens_cpu
|
||||
max_query_len = query_lens.max().item()
|
||||
max_context_len = seq_lens.max().item()
|
||||
|
||||
prefill_metadata = None
|
||||
if self._num_prefills > 0:
|
||||
reqs_start = self._num_decodes # prefill_start
|
||||
tokens_start = self._num_decode_tokens
|
||||
|
||||
prefill_metadata = AscendMLAPrefillMetadata(
|
||||
attn_mask=self.runner.attn_mask,
|
||||
query_lens=query_lens[tokens_start:],
|
||||
context_lens=seq_lens[tokens_start:],
|
||||
input_positions=input_positions[tokens_start:],
|
||||
block_table=block_table[reqs_start:, ...],
|
||||
max_query_len=max_query_len,
|
||||
max_context_len=max_context_len,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if self._num_decodes > 0:
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions[:self._num_decode_tokens],
|
||||
block_table=block_table[:self._num_decode_tokens, ...],
|
||||
seq_lens=seq_lens[:self._num_decode_tokens])
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
num_prefills=self._num_prefills,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
|
||||
class AscendMLAImpl(MLAAttentionImpl):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
||||
# attention backend perspective we rely on the layer to pass in the
|
||||
# correct matrix
|
||||
q_proj: ColumnParallelLinear,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
# Hack for V1 for now to avoid torch library overhead (since we are
|
||||
# already inside an attention custom op), pull out the forward
|
||||
# method from the rotary embedding and call it directly
|
||||
# TODO(lucas): we should probably find a cleaner way to do this
|
||||
self.rotary_emb = rotary_emb.forward_native
|
||||
|
||||
self.q_proj = q_proj
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||
# latter has an additional parameter to control FA2 vs FA3
|
||||
# self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
# if self.vllm_flash_attn_version is not None:
|
||||
# self.flash_attn_varlen_func = \
|
||||
# functools.partial(flash_attn_varlen_func,
|
||||
# fa_version=self.vllm_flash_attn_version)
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
return self.o_proj(x)[0]
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
q_nope, q_pe = self.q_proj(x)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)\
|
||||
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
return ql_nope.transpose(0, 1), q_pe
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
if hasattr(layer, attr):
|
||||
return getattr(layer, attr)
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has no recognized weight attribute:"
|
||||
f" {WEIGHT_NAMES}.")
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(layer.input_size_per_partition,
|
||||
dtype=act_dtype,
|
||||
device=get_layer_weight(layer).device)
|
||||
dequant_weights = layer.quant_method.apply(layer,
|
||||
eye,
|
||||
bias=None)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata.prefill is not None
|
||||
|
||||
# TODO: enable this compute for flash attention computation
|
||||
# kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||
# -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
# k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
# key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
# v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]],
|
||||
# value=0)
|
||||
num_tokens = query.size(0)
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
||||
vanilla_chunked_prefill_mla(
|
||||
output=attn_output,
|
||||
query=query,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
block_tables=attn_metadata.prefill.block_table,
|
||||
query_lens=attn_metadata.prefill.query_lens,
|
||||
context_lens=attn_metadata.prefill.context_lens,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
max_query_len=attn_metadata.prefill.max_query_len,
|
||||
max_context_len=attn_metadata.prefill.max_context_len,
|
||||
nope_dim=self.qk_nope_head_dim,
|
||||
rope_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
scale=self.scale,
|
||||
alibi_slopes=None,
|
||||
causal=True)
|
||||
attn_output = attn_output.view(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
return self.o_proj(attn_output)[0]
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
|
||||
decode_meta = attn_metadata.decode
|
||||
assert decode_meta is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
num_tokens = q.size(0)
|
||||
attn_output = torch.randn(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=q,
|
||||
key_cache=kv_c_and_k_pe_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.decode.block_table, # type:ignore
|
||||
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
return self._v_up_proj_and_o_proj(attn_output)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_toks, ...]
|
||||
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
# Restore head dim (for rotary embedding)
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
attn_metadata.num_prefills is not None and \
|
||||
attn_metadata.num_decode_tokens is not None
|
||||
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||
decode_k_pe = k_pe[:num_decode_tokens]
|
||||
|
||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
||||
|
||||
if has_decode:
|
||||
assert attn_metadata.decode is not None
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||
attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
|
||||
decode_k_pe)
|
||||
|
||||
if has_prefill:
|
||||
assert attn_metadata.prefill is not None
|
||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||
|
||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||
attn_metadata.prefill.input_positions,
|
||||
prefill_q_pe.contiguous(), prefill_k_pe)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
concat_and_cache_mla(k_c_normed, k_pe, kv_cache,
|
||||
attn_metadata.slot_mapping.flatten())
|
||||
# TODO: replaced back to ascend ops
|
||||
# key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2)
|
||||
# torch_npu._npu_reshape_and_cache_siso(
|
||||
# key=key,
|
||||
# key_cache=kv_cache,
|
||||
# slot_indices=attn_metadata.slot_mapping.flatten())
|
||||
|
||||
if has_prefill:
|
||||
output[num_decode_tokens:] = self._forward_prefill(
|
||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
|
||||
if has_decode:
|
||||
output[:num_decode_tokens] = self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
|
||||
return output_padded
|
||||
@@ -462,4 +462,4 @@ class LLMDataDistConnector(KVConnectorBase):
|
||||
|
||||
def close(self, ):
|
||||
self.llm_datadist_engine.data_dist.unlink_clusters([self.cluster],
|
||||
5000)
|
||||
5000)
|
||||
|
||||
75
vllm_ascend/distributed/parallel_state.py
Normal file
75
vllm_ascend/distributed/parallel_state.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
|
||||
init_model_parallel_group)
|
||||
|
||||
# vllm-ascend will maintain its own EP GroupCoordinator and ETP GroupCoordinator for
|
||||
# customize parallel solution
|
||||
_EP: Optional[GroupCoordinator] = None
|
||||
_ETP: Optional[list[GroupCoordinator]] = None
|
||||
|
||||
|
||||
def get_ep_group() -> GroupCoordinator:
|
||||
assert _EP is not None, ("expert model parallel group is not initialized")
|
||||
return _EP
|
||||
|
||||
|
||||
def get_etp_group() -> GroupCoordinator:
|
||||
assert _ETP is not None, (
|
||||
"expert tensor parallel group is not initialized")
|
||||
return _ETP
|
||||
|
||||
|
||||
def init_ascend_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
expert_tensor_parallel_size: int = 1,
|
||||
backend: Optional[str] = None,
|
||||
):
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size: int = torch.distributed.get_world_size()
|
||||
backend = backend or torch.distributed.get_backend(
|
||||
get_world_group().device_group)
|
||||
num_expert_parallel_groups: int = expert_tensor_parallel_size
|
||||
num_expert_tensor_parallel_groups: int = (world_size //
|
||||
expert_tensor_parallel_size)
|
||||
|
||||
global _EP
|
||||
assert _EP is None, ("expert parallel group is already initialized")
|
||||
group_ranks = []
|
||||
for i in range(num_expert_parallel_groups):
|
||||
ranks = list(range(i, world_size, num_expert_parallel_groups))
|
||||
group_ranks.append(ranks)
|
||||
|
||||
_EP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="ep")
|
||||
|
||||
group_ranks = []
|
||||
global _ETP
|
||||
assert _ETP is None, (
|
||||
"expert tensor parallel group is already initialized")
|
||||
for i in range(num_expert_tensor_parallel_groups):
|
||||
ranks = list(
|
||||
range(i * expert_tensor_parallel_size,
|
||||
(i + 1) * expert_tensor_parallel_size))
|
||||
group_ranks.append(ranks)
|
||||
|
||||
_ETP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="etp")
|
||||
|
||||
|
||||
def destory_ascend_model_parallel():
|
||||
global _EP
|
||||
if _EP:
|
||||
_EP.destroy()
|
||||
_EP = None
|
||||
|
||||
global _ETP
|
||||
if _ETP:
|
||||
_ETP.destroy()
|
||||
_ETP = None
|
||||
@@ -2,8 +2,15 @@ from vllm import ModelRegistry
|
||||
|
||||
|
||||
def register_model():
|
||||
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
|
||||
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
|
||||
from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401
|
||||
from .qwen2_vl import CustomQwen2VLForConditionalGeneration # noqa: F401
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepSeekMTPModel",
|
||||
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
"vllm_ascend.models.qwen2_vl:CustomQwen2VLForConditionalGeneration")
|
||||
|
||||
173
vllm_ascend/models/deepseek_mtp.py
Normal file
173
vllm_ascend/models/deepseek_mtp.py
Normal file
@@ -0,0 +1,173 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/qwen2_vl.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import \
|
||||
VocabParallelEmbedding
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
||||
SharedHead)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
|
||||
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
|
||||
|
||||
|
||||
class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
|
||||
model_config,
|
||||
cache_config,
|
||||
quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
|
||||
torch.zeros_like(inputs_embeds),
|
||||
inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx): CustomDeepSeekMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
|
||||
# Note: torch._dynamo.exc.Unsupported: builtin: str
|
||||
self.layers_list = [
|
||||
self.layers[str(idx)]
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
]
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
return self.layers_list[current_step_idx](
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches[current_step_idx],
|
||||
attn_metadata,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
mtp_layer = self.layers_list[current_step_idx]
|
||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||
mtp_layer.shared_head(hidden_states),
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
||||
class CustomDeepSeekMTP(DeepSeekMTP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
|
||||
self.sampler = get_sampler()
|
||||
@@ -19,36 +19,77 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapted from
|
||||
# vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
||||
"""Inference-only DeepseekV2/DeepseekV3 model."""
|
||||
from typing import Optional, Union
|
||||
# <<<<<<< HEAD
|
||||
# # Adapted from
|
||||
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
|
||||
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
||||
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
||||
# from typing import Optional, Union
|
||||
|
||||
# import torch
|
||||
# from torch import nn
|
||||
# from transformers import PretrainedConfig
|
||||
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
# from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
# from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
# from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
# from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
# from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
# from vllm.model_executor.layers.sampler import get_sampler
|
||||
# from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
# ParallelLMHead, VocabParallelEmbedding)
|
||||
# from vllm.model_executor.models.deepseek_v2 import ( # noqa
|
||||
# DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
|
||||
# DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE)
|
||||
# =======
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.distributed import (get_dp_group, get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_v2 import ( # noqa
|
||||
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
|
||||
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE)
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
DeepseekV2ForCausalLM # ruff: noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
yarn_get_mscale # ruff: noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
|
||||
DeepseekV2DecoderLayer,
|
||||
DeepseekV2MLAAttention,
|
||||
DeepseekV2MLP)
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
# >>>>>>> dcd5c73 (Feat: Graph mode for deepseek v2/v3.)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
|
||||
|
||||
class CustomDeepseekV2MoE(DeepseekV2MoE):
|
||||
|
||||
class CustomDeepseekV2MoE(nn.Module):
|
||||
|
||||
top_k: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -56,10 +97,15 @@ class CustomDeepseekV2MoE(DeepseekV2MoE):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_shared_experts = config.n_shared_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
if self.tp_size > config.n_routed_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.n_routed_experts}.")
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
@@ -76,7 +122,7 @@ class CustomDeepseekV2MoE(DeepseekV2MoE):
|
||||
else:
|
||||
self.gate.e_score_correction_bias = None
|
||||
|
||||
self.experts = FusedMoE(
|
||||
self.experts = AscendFusedMoE(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
@@ -99,9 +145,248 @@ class CustomDeepseekV2MoE(DeepseekV2MoE):
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
reduce_results=True,
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.dp_size = get_dp_group().world_size
|
||||
batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1
|
||||
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.final_hidden_states = torch.zeros(
|
||||
[batch_size, config.hidden_size], dtype=params_dtype, device="npu")
|
||||
self.tp_group = get_tp_group().device_group
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata is None:
|
||||
# for profile run
|
||||
return hidden_states
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
if (self.tp_size > 1 and self.enable_mc2
|
||||
and attn_metadata.num_prefills == 0):
|
||||
# hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
# hidden_states, "sum", scatter_dim=0, group=self.tp_group
|
||||
# )
|
||||
chunks = torch.chunk(hidden_states,
|
||||
get_tp_group().world_size,
|
||||
dim=0)
|
||||
hidden_states = chunks[get_tp_group().rank_in_group]
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
is_prefill = True if attn_metadata.num_prefills > 0 else False
|
||||
# is_prefill = attn_metadata.num_prefills > 0
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
|
||||
|
||||
if self.tp_size > 1:
|
||||
if self.enable_mc2 and not is_prefill:
|
||||
dist.all_gather_into_tensor(self.final_hidden_states,
|
||||
final_hidden_states, self.tp_group)
|
||||
final_hidden_states = self.final_hidden_states
|
||||
else:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
|
||||
self.num_heads = num_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert num_heads % tp_size == 0
|
||||
self.num_local_heads = num_heads // tp_size
|
||||
|
||||
self.scaling = self.qk_head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_a_proj")
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
|
||||
self.num_heads *
|
||||
self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_b_proj")
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
||||
self.num_heads *
|
||||
self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj")
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_b_proj")
|
||||
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
if rope_scaling:
|
||||
rope_scaling["rope_type"] = 'deepseek_yarn'
|
||||
self.rotary_emb = get_rope(qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=False)
|
||||
if rope_scaling:
|
||||
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
# In the MLA backend, kv_cache includes both k_c and
|
||||
# pe (i.e. decoupled position embeddings). In particular,
|
||||
# the concat_and_cache_mla op requires
|
||||
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
|
||||
# i.e.
|
||||
# kv_lora_rank + qk_rope_head_dim == head_size
|
||||
self.mla_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
# MLA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=self.rotary_emb,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
)
|
||||
|
||||
self.prefix = prefix
|
||||
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||
if VLLM_ENABLE_GRAPH_MODE == "1":
|
||||
self.forward = self.forward_torchair
|
||||
else:
|
||||
self.forward = self.forward_eager # type: ignore
|
||||
|
||||
def forward_torchair(self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor = None,
|
||||
attn_metadata=None):
|
||||
if self.q_lora_rank is not None:
|
||||
ckq = self.q_a_proj(hidden_states)[0]
|
||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
else:
|
||||
hidden_states_or_q_c = hidden_states
|
||||
return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
|
||||
kv_cache, attn_metadata)
|
||||
|
||||
def forward_eager(self, positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor):
|
||||
if self.q_lora_rank is not None:
|
||||
ckq = self.q_a_proj(hidden_states)[0]
|
||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
else:
|
||||
hidden_states_or_q_c = hidden_states
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
return self.mla_attn(hidden_states_or_q_c,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
output_shape=hidden_states.shape)
|
||||
|
||||
# def forward(
|
||||
# self,
|
||||
# positions: torch.Tensor,
|
||||
# hidden_states: torch.Tensor,
|
||||
# # torchair should pass below two parameters
|
||||
# kv_cache: torch.Tensor = None,
|
||||
# attn_metadata: AttentionMetadata = None,
|
||||
# ) -> torch.Tensor:
|
||||
# if self.q_lora_rank is not None:
|
||||
# ckq = self.q_a_proj(hidden_states)[0]
|
||||
# hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
# else:
|
||||
# hidden_states_or_q_c = hidden_states
|
||||
# if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
# return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
|
||||
# kv_cache, attn_metadata)
|
||||
# else:
|
||||
# kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||
# [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
# kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||
# return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, output_shape=hidden_states.shape)
|
||||
# kv_cache, attn_metadata)
|
||||
|
||||
|
||||
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
@@ -124,8 +409,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
# with the layer's index.
|
||||
layer_idx = int(prefix.split(sep='.')[-1])
|
||||
self.layer_idx = layer_idx
|
||||
# TODO: enable mla in vllm-ascend
|
||||
if model_config.use_mla:
|
||||
attn_cls = DeepseekV2MLAAttention
|
||||
attn_cls = CustomDeepseekV2MLAAttention
|
||||
else:
|
||||
attn_cls = DeepseekV2Attention
|
||||
self.self_attn = attn_cls(
|
||||
@@ -180,8 +466,8 @@ class CustomDeepseekV2Model(nn.Module):
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
|
||||
@@ -18,3 +18,4 @@ import vllm_ascend.ops.activation # noqa
|
||||
import vllm_ascend.ops.fused_moe # noqa
|
||||
import vllm_ascend.ops.layernorm # noqa
|
||||
import vllm_ascend.ops.rotary_embedding # noqa
|
||||
import vllm_ascend.ops.vocab_parallel_embedding # noqa
|
||||
|
||||
293
vllm_ascend/ops/attention.py
Normal file
293
vllm_ascend/ops/attention.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
|
||||
|
||||
# Implementation of vanilla chunked prefill, should be removed after the kernel is ready for
|
||||
# all the corner case
|
||||
def vanilla_chunked_prefill(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor, # (num_tokens, heads, head_size)
|
||||
key_cache: torch.Tensor, # (num_blocks, block_size, kv_heads, head_size)
|
||||
value_cache: torch.
|
||||
Tensor, # (num_blocks, block_size, kv_heads, head_size,)
|
||||
block_tables: torch.Tensor, # (num_seqs, max_num_blocks_per_seq)
|
||||
cu_seqlen_q: torch.Tensor, # (num_seqs + 1,)
|
||||
cu_seqlen_k: torch.Tensor, # (num_seqs + 1,)
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
causal: bool = True,
|
||||
) -> None:
|
||||
num_query_heads = query.shape[1]
|
||||
head_dim = value_cache.shape[3]
|
||||
num_kv_heads = value_cache.shape[2]
|
||||
block_size = value_cache.shape[1]
|
||||
num_batch = cu_seqlen_q.shape[0] - 1
|
||||
max_num_blocks_per_seq = block_tables.shape[1]
|
||||
|
||||
key = key_cache[block_tables].view(num_batch,
|
||||
max_num_blocks_per_seq * block_size,
|
||||
num_kv_heads, head_dim)
|
||||
|
||||
value = value_cache[block_tables].view(num_batch,
|
||||
max_num_blocks_per_seq * block_size,
|
||||
num_kv_heads, head_dim)
|
||||
key = key[:, :max_seqlen_k, :, :]
|
||||
value = value[:, :max_seqlen_k, :, :]
|
||||
|
||||
seqlen_k = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
|
||||
seqlen_q = cu_seqlen_q[1:] - cu_seqlen_q[:-1]
|
||||
seqlen_q = seqlen_q.view(-1, 1)
|
||||
seqlen_k = seqlen_k.view(-1, 1)
|
||||
seqlen_diff = seqlen_k - seqlen_q
|
||||
q_idx_mask = (torch.arange(0, max_seqlen_q,
|
||||
device="npu").view(1, -1).repeat(num_batch, 1))
|
||||
k_idx_mask = (torch.arange(0, max_seqlen_k,
|
||||
device="npu").view(1, -1).repeat(num_batch, 1))
|
||||
q_mask = q_idx_mask < seqlen_q
|
||||
k_mask = k_idx_mask < seqlen_k
|
||||
|
||||
# calculate idx for causal mask of query [batch, max_seqlen_q]
|
||||
causal_mask_idx = (q_idx_mask + seqlen_diff)[q_mask]
|
||||
|
||||
# generate causal mask [batch, max_seqlen_q, max_seqlen_k]
|
||||
tril_mask = torch.tril(torch.ones(max_seqlen_k, max_seqlen_k,
|
||||
device="npu"))
|
||||
tril_mask[tril_mask == 0] = float("-inf")
|
||||
tril_mask[tril_mask == 1] = 0
|
||||
causal_mask = tril_mask[causal_mask_idx]
|
||||
causal_mask_padding = torch.empty([num_batch, max_seqlen_q, max_seqlen_k],
|
||||
device="npu").fill_(float("-inf"))
|
||||
causal_mask_padding[q_mask] = causal_mask
|
||||
# to [batch, num_heads, max_seqlen_q, max_seqlen_k]
|
||||
causal_mask_padding = causal_mask_padding.unsqueeze(1)
|
||||
|
||||
pad_q = torch.zeros(
|
||||
[num_batch, max_seqlen_q, num_query_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=query.dtype,
|
||||
)
|
||||
pad_k = torch.zeros(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=key.dtype,
|
||||
)
|
||||
pad_v = torch.zeros(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=value.dtype,
|
||||
)
|
||||
pad_q[q_mask] = query
|
||||
pad_k[k_mask] = key[k_mask]
|
||||
pad_v[k_mask] = value[k_mask]
|
||||
|
||||
if num_query_heads > num_kv_heads:
|
||||
pad_k = pad_k.view(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
|
||||
pad_k = pad_k.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
|
||||
[num_batch, max_seqlen_k, num_query_heads, head_dim])
|
||||
pad_v = pad_v.view(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
|
||||
pad_v = pad_v.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
|
||||
[num_batch, max_seqlen_k, num_query_heads, head_dim])
|
||||
# permute to [b, h, n, k]
|
||||
pad_q = pad_q.permute(0, 2, 1, 3)
|
||||
pad_k = pad_k.permute(0, 2, 1, 3)
|
||||
pad_v = pad_v.permute(0, 2, 1, 3)
|
||||
attn_mask = torch.empty([num_batch, 1, 1, max_seqlen_k],
|
||||
device="npu").fill_(float("-inf"))
|
||||
attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)
|
||||
# [b, h, f, t]
|
||||
attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
|
||||
attn_weights *= scale
|
||||
attn_mask = attn_mask.float()
|
||||
attn_weights = attn_weights + attn_mask
|
||||
if causal:
|
||||
attn_weights = attn_weights + causal_mask_padding
|
||||
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
attn_output = (attn_output[q_mask].view([-1, num_query_heads,
|
||||
head_dim]).to(output.dtype))
|
||||
output.copy_(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
def vanilla_chunked_prefill_mla(
|
||||
output: torch.Tensor, # (num_tokens, num_heads, v_head_dim)
|
||||
query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim)
|
||||
kv_cache: torch.Tensor, # (num_blocks, block_size, latent_kv)
|
||||
block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq)
|
||||
query_lens: torch.Tensor, # (batch_size)
|
||||
context_lens: torch.Tensor, # (batch_size)
|
||||
kv_b_proj: ColumnParallelLinear, # ()
|
||||
max_query_len: int,
|
||||
max_context_len: int,
|
||||
nope_dim: int,
|
||||
rope_dim: int,
|
||||
v_head_dim: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
causal: bool = True) -> None:
|
||||
batch_size = block_tables.size(0)
|
||||
assert query_lens.size(0) == batch_size
|
||||
num_heads = query.size(1)
|
||||
block_size = kv_cache.size(1)
|
||||
latent_kv_dim = kv_cache.size(3) - rope_dim
|
||||
max_num_blocks_per_seq = block_tables.size(1)
|
||||
batch_size = query_lens.size(0)
|
||||
kv_cache = kv_cache.squeeze()
|
||||
# select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim]
|
||||
cache_kv_c_pe = kv_cache[block_tables].view(
|
||||
batch_size, max_num_blocks_per_seq * block_size,
|
||||
latent_kv_dim + rope_dim)[:, :max_context_len, :]
|
||||
# get kv_c and k_pe
|
||||
# cached_kv_c: [batch_size, max_context_len, latent_kv]
|
||||
# cached_k_pe: [batch_size, max_context_len, rope_dim]
|
||||
cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim]
|
||||
cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:]
|
||||
# get k_rope and v
|
||||
# k_nope: [batch_size, max_context_len, num_heads, nope_dim]
|
||||
# value: [batch_size, max_context_len, num_heads, v_head_dim]
|
||||
k_nope, value = kv_b_proj(cache_kv_c)[0].view(
|
||||
batch_size, max_context_len, num_heads,
|
||||
nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)
|
||||
# key: [batch_size, max_context_len, num_hads, rope_dim + nope_dim]
|
||||
key = torch.cat(
|
||||
[k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)],
|
||||
dim=-1)
|
||||
|
||||
context_lens = context_lens.view(-1, 1).to("npu")
|
||||
query_lens = query_lens.view(-1, 1).to("npu")
|
||||
seq_diff = context_lens - query_lens
|
||||
|
||||
q_idx_mask = (torch.arange(0, max_query_len,
|
||||
device="npu").view(1, -1).repeat(batch_size, 1))
|
||||
kv_c_idx_mask = (torch.arange(0, max_context_len,
|
||||
device="npu").view(1,
|
||||
-1).repeat(batch_size, 1))
|
||||
kv_c_mask = kv_c_idx_mask < context_lens
|
||||
q_mask = q_idx_mask < query_lens
|
||||
|
||||
# calculate idx for causal mask of query [batch, max_seqlen_q]
|
||||
causal_mask_idx = (q_idx_mask + seq_diff)[q_mask]
|
||||
|
||||
# generate causal mask [batch, max_seqlen_q, max_seqlen_k]
|
||||
tril_mask = torch.tril(
|
||||
torch.ones(max_context_len, max_context_len, device="npu"))
|
||||
tril_mask[tril_mask == 0] = float("-inf")
|
||||
tril_mask[tril_mask == 1] = 0
|
||||
causal_mask = tril_mask[causal_mask_idx]
|
||||
causal_mask_padding = torch.empty(
|
||||
[batch_size, max_query_len, max_context_len],
|
||||
device="npu").fill_(float("-inf"))
|
||||
causal_mask_padding[q_mask] = causal_mask
|
||||
# to [batch, num_heads, max_seqlen_q, max_seqlen_k]
|
||||
causal_mask_padding = causal_mask_padding.unsqueeze(1)
|
||||
|
||||
pad_q = torch.zeros(
|
||||
[batch_size, max_query_len, num_heads, rope_dim + nope_dim],
|
||||
device="npu",
|
||||
dtype=query.dtype,
|
||||
)
|
||||
pad_k = torch.zeros(
|
||||
[batch_size, max_context_len, num_heads, rope_dim + nope_dim],
|
||||
device="npu",
|
||||
dtype=key.dtype,
|
||||
)
|
||||
pad_v = torch.zeros(
|
||||
[batch_size, max_context_len, num_heads, v_head_dim],
|
||||
device="npu",
|
||||
dtype=value.dtype,
|
||||
)
|
||||
pad_q[q_mask] = query
|
||||
pad_k[kv_c_mask] = key[kv_c_mask]
|
||||
pad_v[kv_c_mask] = value[kv_c_mask]
|
||||
|
||||
pad_q = pad_q.permute(0, 2, 1, 3)
|
||||
pad_k = pad_k.permute(0, 2, 1, 3)
|
||||
pad_v = pad_v.permute(0, 2, 1, 3)
|
||||
attn_mask = torch.empty([batch_size, 1, 1, max_context_len],
|
||||
device="npu").fill_(float("-inf"))
|
||||
attn_mask[:, :, :, :max_context_len].masked_fill_(
|
||||
kv_c_mask[:, None, None, :], 0)
|
||||
# [b, h, f, t]
|
||||
attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
|
||||
attn_weights *= scale
|
||||
attn_mask = attn_mask.float()
|
||||
attn_weights = attn_weights + attn_mask
|
||||
if causal:
|
||||
attn_weights = attn_weights + causal_mask_padding
|
||||
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
attn_output = (attn_output[q_mask].view([-1, num_heads,
|
||||
v_head_dim]).to(output.dtype))
|
||||
output.copy_(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
def vanilla_decode_mla(
|
||||
query: torch.Tensor, # [num_tokens, num_heads, latent_dim + rope_dim]
|
||||
key_cache: torch.
|
||||
Tensor, # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim]
|
||||
num_kv_heads: int,
|
||||
num_heads: int,
|
||||
scale: float,
|
||||
block_table: torch.Tensor, # [batch_size, max_block_size]
|
||||
context_lens: List[int],
|
||||
mla_vhead_size: int,
|
||||
rope_dim: int,
|
||||
output: torch.Tensor):
|
||||
batch_size = block_table.size()[0]
|
||||
max_block_size = block_table.size()[1]
|
||||
reduce_dim = key_cache.size()[-1]
|
||||
block_size = key_cache.size()[1]
|
||||
latent_dim = reduce_dim - rope_dim
|
||||
kv_c_and_pe = key_cache[block_table].view(
|
||||
[batch_size, max_block_size * block_size, num_kv_heads, reduce_dim])
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1)
|
||||
# [batch_size, max_context_len, num_kv_heads, latent_dim + rope_dim]
|
||||
# since the kv head is 1 in deepseek, we use expand here for perf
|
||||
kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand(
|
||||
-1, -1, num_heads, 1)
|
||||
kv_c = kv_c_and_pe[..., :latent_dim]
|
||||
kv_idx_mask = (torch.arange(0, max_context_len,
|
||||
device="npu").view(1,
|
||||
-1).repeat(batch_size, 1))
|
||||
# [batch_size, max_context_len]
|
||||
kv_idx_mask = kv_idx_mask < context_lens
|
||||
query = query.unsqueeze(1)
|
||||
attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe)
|
||||
attn_weights *= scale
|
||||
attn_weights = attn_weights + kv_idx_mask[:, -1, -1, :].float()
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights,
|
||||
kv_c.float()).view(-1, num_heads, latent_dim)
|
||||
output.copy_(attn_output)
|
||||
return output
|
||||
35
vllm_ascend/ops/cache.py
Normal file
35
vllm_ascend/ops/cache.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def concat_and_cache_mla(
|
||||
kv_c_normed: torch.Tensor, # [num_tokens, num_kv_head, nope]
|
||||
k_pe: torch.Tensor, # [num_tokens, num_kv_head, rope]
|
||||
kv_cache: torch.
|
||||
Tensor, # [num_blocks, block_size, num_kv_head, nope + rope]
|
||||
slot_mapping, # [num_tokens]
|
||||
):
|
||||
num_blocks = kv_cache.size()[0]
|
||||
block_size = kv_cache.size()[1]
|
||||
num_kv_head = k_pe.size()[1]
|
||||
|
||||
idx_for_copy = slot_mapping // block_size * block_size + slot_mapping % block_size
|
||||
kv_cache = kv_cache.view(num_blocks * block_size, num_kv_head, -1)
|
||||
kv_cache[idx_for_copy] = torch.cat([kv_c_normed.unsqueeze(1), k_pe],
|
||||
dim=-1)
|
||||
@@ -15,12 +15,131 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.fused_moe.layer import \
|
||||
UnquantizedFusedMoEMethod
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizeMethodBase
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
||||
|
||||
|
||||
def fused_experts_with_mc2(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
global_bs = 0
|
||||
moe_expert_num = len(expert_map)
|
||||
kwargs = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": topk_ids,
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": global_bs,
|
||||
}
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
quant_mode = 0
|
||||
ep_group = get_ep_group().device_group
|
||||
local_rank = torch.distributed.get_rank(group=ep_group)
|
||||
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
||||
|
||||
world_szie = torch.distributed.get_world_size()
|
||||
tp_size = world_szie // all_to_all_group_size
|
||||
tp_rank = rank % tp_size
|
||||
|
||||
stage1_kwargs = {
|
||||
"scales": None,
|
||||
"quant_mode": quant_mode,
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": all_to_all_group_size,
|
||||
"ep_rank_id": local_rank,
|
||||
# "group_tp": self.moe_rs_group_name,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": tp_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
}
|
||||
kwargs.update(stage1_kwargs)
|
||||
|
||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
||||
0:5]
|
||||
|
||||
w1 = w1.transpose(1, 2)
|
||||
expert_token_nums = torch.cumsum(expert_token_nums,
|
||||
dim=0,
|
||||
dtype=torch.int64)
|
||||
group_list = expert_token_nums.to(torch.int64)
|
||||
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[expand_x],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)
|
||||
|
||||
# TODO: Remove this in the future.
|
||||
gate_up_out = torch.cat(gate_up_out_list, dim=0)
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)
|
||||
|
||||
down_out_list = torch.cat(down_out_list, dim=0)
|
||||
|
||||
# moeCombine
|
||||
kwargs = {
|
||||
"expand_x": down_out_list,
|
||||
"expert_ids": topk_ids,
|
||||
"expand_idx": expand_idx,
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
}
|
||||
tp_recv_counts = output[5]
|
||||
stage3_kwargs = {
|
||||
"ep_send_counts": ep_recv_counts,
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": all_to_all_group_size,
|
||||
"ep_rank_id": local_rank,
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
# "group_tp": self.moe_rs_group_name,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": tp_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
}
|
||||
kwargs.update(stage3_kwargs)
|
||||
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts(
|
||||
@@ -47,22 +166,27 @@ def fused_experts(
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
"""
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
"""
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(w1.shape)
|
||||
# print(hidden_states.shape)
|
||||
|
||||
original_shape = hidden_states.shape
|
||||
assert len(original_shape) == 2
|
||||
# assert len(original_shape) == 2
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
num_experts = w1.shape[0]
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
||||
], "Only float32, float16, and bfloat16 are supported"
|
||||
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
||||
# ], "Only float32, float16, and bfloat16 are supported"
|
||||
|
||||
if expert_map is not None:
|
||||
# Generate token indices and flatten
|
||||
@@ -152,11 +276,18 @@ def fused_experts(
|
||||
final_hidden_states = torch.zeros(*original_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=dtype)
|
||||
final_hidden_states.index_add_(0, sorted_token_indices,
|
||||
weighted_down_out)
|
||||
# TODO: This should not happen! Look into it!
|
||||
# fill nan with 0.0
|
||||
final_hidden_states[torch.isnan(final_hidden_states)] = 0.0
|
||||
|
||||
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# This created multiple NaN and index_add_ will mix them up which harms accracy
|
||||
# remove this mask and filter after it being fixed
|
||||
num_valid_tokens = mask.sum()
|
||||
valid_token_mask = torch.arange(
|
||||
0, sorted_token_indices.shape[0],
|
||||
device=device).unsqueeze(1) < num_valid_tokens
|
||||
valid_output = torch.where(
|
||||
valid_token_mask, weighted_down_out,
|
||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
||||
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
|
||||
else:
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
@@ -199,16 +330,17 @@ def native_grouped_topk(
|
||||
|
||||
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: Optional[bool] = True
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
@@ -232,8 +364,23 @@ def select_experts(
|
||||
Raises:
|
||||
ValueError: If an unsupported scoring function is provided.
|
||||
"""
|
||||
assert hidden_states.shape[0] == router_logits.shape[0], (
|
||||
"Number of tokens mismatch")
|
||||
# assert hidden_states.shape[0] == router_logits.shape[0], (
|
||||
# "Number of tokens mismatch")
|
||||
# if os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1" and not is_prefill:
|
||||
# topk_weight, topk_idx, _ = torch.ops.npu_inference.npu_moe_gating_top_k(
|
||||
# router_logits,
|
||||
# k=top_k, # topk当前写8
|
||||
# bias=e_score_correction_bias,
|
||||
# k_group=topk_group, # fix: 4
|
||||
# group_count=num_expert_group, # fix 8
|
||||
# group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
||||
# renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
# norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# # out_flag=False, # todo new api; 第三个输出是否输出
|
||||
# # y2_flag=False, # old api; 第三个输出是否输出
|
||||
# routed_scaling_factor=1,
|
||||
# eps=float(1e-20))
|
||||
# return topk_weight, topk_idx
|
||||
|
||||
if custom_routing_function is not None:
|
||||
raise NotImplementedError(
|
||||
@@ -261,14 +408,16 @@ def select_experts(
|
||||
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
||||
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
|
||||
topk_group)
|
||||
|
||||
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(topk_weights, k=top_k, dim=-1,
|
||||
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_weights.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(topk_weights,
|
||||
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)
|
||||
@@ -285,46 +434,245 @@ def select_experts(
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
ep_group = get_ep_group()
|
||||
self.ep_size = ep_group.world_size
|
||||
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.local_batch_size = self.global_batch_size // self.ep_size
|
||||
|
||||
try:
|
||||
device_group = ep_group.device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
||||
local_rank)
|
||||
except AttributeError:
|
||||
self.moe_all_to_all_group_name = None
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
super(UnquantizedFusedMoEMethod,
|
||||
self).process_weights_after_loading(layer)
|
||||
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w13_weight.data),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w2_weight.data),
|
||||
requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill=False,
|
||||
**kwargs,
|
||||
):
|
||||
# assert router_logits.shape[
|
||||
# 1] == global_num_experts, "Number of global experts mismatch"
|
||||
# set prefill as false always, should fix this
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if os.environ.get("VLLM_ENABLE_MC2") == "1" and not is_prefill:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||
else:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
|
||||
|
||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
|
||||
def __init__(self,
|
||||
num_experts,
|
||||
top_k,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
params_dtype=None,
|
||||
reduce_results=False,
|
||||
renormalize=True,
|
||||
use_grouped_topk=False,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
quant_config=None,
|
||||
tp_size=None,
|
||||
ep_size=None,
|
||||
dp_size=None,
|
||||
prefix="",
|
||||
custom_routing_function=None,
|
||||
scoring_func="softmax",
|
||||
e_score_correction_bias=None,
|
||||
activation="silu"):
|
||||
super(FusedMoE, self).__init__()
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
self.ep_size = get_ep_group().world_size
|
||||
self.tp_size = get_etp_group().world_size
|
||||
self.dp_size = (dp_size
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
self.dp_rank = (0
|
||||
if self.dp_size == 1 else get_dp_group().rank_in_group)
|
||||
|
||||
self.top_k = top_k
|
||||
self.num_experts = num_experts
|
||||
self.global_num_experts = num_experts
|
||||
assert intermediate_size % self.tp_size == 0
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
self.reduce_results = reduce_results
|
||||
self.renormalize = renormalize
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
if self.use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.scoring_func = scoring_func
|
||||
self.e_score_correction_bias = e_score_correction_bias
|
||||
self.expert_map = None
|
||||
self.activation = activation
|
||||
|
||||
if self.ep_size > 1:
|
||||
# Create a tensor of size num_experts filled with -1
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
self.ep_size,
|
||||
get_ep_group().rank_in_group, self.global_num_experts)
|
||||
self.tp_rank = get_etp_group().rank_in_group
|
||||
self.ep_rank = get_ep_group().rank_in_group
|
||||
else:
|
||||
# Adjust TP size for DP attention
|
||||
# haven't test its functionality yet, may remove in the future
|
||||
self.tp_rank = self.tp_size * self.dp_rank
|
||||
self.ep_rank = 0
|
||||
self.tp_size = self.tp_size * self.dp_size
|
||||
self.ep_size = 1
|
||||
self.local_num_experts = self.global_num_experts
|
||||
self.expert_map = None
|
||||
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
"non-grouped topk.")
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
AscendUnquantizedFusedMoEMethod())
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
assert self.quant_method is not None
|
||||
|
||||
local_num_experts = torch.sum(self.expert_map != -1) \
|
||||
if self.expert_map is not None else num_experts
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size_per_partition":
|
||||
self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_prefill: bool,
|
||||
top_k=None):
|
||||
assert self.quant_method is not None
|
||||
|
||||
if top_k:
|
||||
real_top_k = top_k
|
||||
else:
|
||||
real_top_k = self.top_k
|
||||
|
||||
if self.dp_size > 1:
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
|
||||
) == 1 and not is_prefill:
|
||||
...
|
||||
elif int(os.environ.get("USING_LCCL_COM")) == 1: # type: ignore
|
||||
hidden_states = get_dp_group().all_gather(
|
||||
hidden_states, 0, False)
|
||||
router_logits = get_dp_group().all_gather(
|
||||
router_logits, 0, False)
|
||||
else:
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=real_top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if self.dp_size > 1:
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
|
||||
) == 1 and not is_prefill:
|
||||
...
|
||||
else:
|
||||
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
final_hidden_states,
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group=get_dp_group().device_group)
|
||||
|
||||
# if self.reduce_results and self.tp_size > 1:
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -38,7 +39,7 @@ def rope_forward_oot(
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if CUSTOM_OP_ENABLED and self.is_neox_style:
|
||||
if CUSTOM_OP_ENABLED and self.is_neox_style and self.head_size % 32 == 0:
|
||||
return torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
@@ -66,5 +67,169 @@ def rope_forward_oot(
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
def native_rope_deepseek_forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# seq_len = positions.max() + 1
|
||||
seq_len = self.max_position_embeddings
|
||||
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
# if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
|
||||
# self._set_cos_sin_cache(seq_len=seq_len, device=query.device, dtype=query.dtype)
|
||||
self._set_cos_sin_cache(seq_len=seq_len,
|
||||
device=query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
cos = self.cos_cached[:seq_len].to(dtype=query.dtype)
|
||||
sin = self.sin_cached[:seq_len].to(dtype=query.dtype)
|
||||
|
||||
q_pe, k_pe = apply_rotary_pos_emb(query, key, cos, sin, positions)
|
||||
|
||||
return q_pe, k_pe
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
def yarn_find_correction_dim(num_rotations,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
return (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def yarn_find_correction_range(low_rot,
|
||||
high_rot,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
low = torch.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = torch.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
# Note: use torch instead of max/min to solve MTP compilation error.
|
||||
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_value, max_value, dim):
|
||||
# Note: The if conditional branch is not used here
|
||||
# to solve MTP compilation error.
|
||||
max_value += (min_value == max_value).float() * 0.001
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) -
|
||||
min_value) / (max_value - min_value)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids]
|
||||
sin = sin[position_ids]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
|
||||
if len(q.shape) == 3:
|
||||
q = q[:, :, None, :]
|
||||
if len(k.shape) == 2:
|
||||
k = k[:, None, None, :]
|
||||
elif len(k.shape) == 3:
|
||||
k = k[:, :, None, :]
|
||||
|
||||
b, h_q, s, d = q.shape
|
||||
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
|
||||
|
||||
b, h_k, s, d = k.shape
|
||||
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
q_embed = q_embed.view(b, h_q, d)
|
||||
k_embed = k_embed.view(b, h_k, d)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
seq_len = self.max_position_embeddings
|
||||
self.max_seq_len_cached = seq_len
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
freq_inter = 1.0 / (self.scaling_factor * self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
|
||||
device=device, dtype=torch.float32)
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
# _mscale = float(
|
||||
# yarn_get_mscale(self.scaling_factor, self.mscale)
|
||||
# / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
|
||||
# )
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype),
|
||||
persistent=False)
|
||||
self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype),
|
||||
persistent=False)
|
||||
|
||||
|
||||
# TODO: Patch when aclnn ops avaiable
|
||||
RotaryEmbedding.forward_oot = rope_forward_oot
|
||||
DeepseekScalingRotaryEmbedding.forward = rope_forward_oot
|
||||
# DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
|
||||
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
|
||||
DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache
|
||||
DeepseekScalingRotaryEmbedding.max_seq_len_cached = None
|
||||
|
||||
67
vllm_ascend/ops/vocab_parallel_embedding.py
Normal file
67
vllm_ascend/ops/vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import \
|
||||
VocabParallelEmbedding
|
||||
|
||||
|
||||
def get_masked_input_and_mask(
|
||||
input_: torch.Tensor, org_vocab_start_index: int,
|
||||
org_vocab_end_index: int, num_org_vocab_padding: int,
|
||||
added_vocab_start_index: int,
|
||||
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# torch.compile will fuse all of the pointwise ops below
|
||||
# into a single kernel, making it very fast
|
||||
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
|
||||
org_vocab_end_index)
|
||||
added_vocab_mask = (input_ >= added_vocab_start_index) & (
|
||||
input_ < added_vocab_end_index)
|
||||
added_offset = added_vocab_start_index - (
|
||||
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
|
||||
valid_offset = (org_vocab_start_index *
|
||||
org_vocab_mask) + (added_offset * added_vocab_mask)
|
||||
vocab_mask = org_vocab_mask | added_vocab_mask
|
||||
input_ = vocab_mask * (input_ - valid_offset)
|
||||
return input_, ~vocab_mask
|
||||
|
||||
|
||||
def vocab_parallel_embedding_forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
masked_input, input_mask = get_masked_input_and_mask(
|
||||
input_, self.shard_indices.org_vocab_start_index,
|
||||
self.shard_indices.org_vocab_end_index,
|
||||
self.shard_indices.num_org_vocab_padding,
|
||||
self.shard_indices.added_vocab_start_index,
|
||||
self.shard_indices.added_vocab_end_index)
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
|
||||
|
||||
|
||||
VocabParallelEmbedding.forward = vocab_parallel_embedding_forward
|
||||
@@ -0,0 +1,16 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
@@ -15,4 +15,5 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import vllm_ascend.patch.platform.patch_0_8_4.patch_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_0_8_4.patch_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_0_8_4.patch_distributed # noqa
|
||||
138
vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py
Normal file
138
vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import torch
|
||||
import vllm
|
||||
import vllm.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
|
||||
_get_default_timeout,
|
||||
is_nccl_available)
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
|
||||
def ascend_destroy_model_parallel():
|
||||
"""Set the groups to none and destroy them."""
|
||||
from vllm.distributed.parallel_state import _DP, _PP, _TP
|
||||
if _TP:
|
||||
_TP.destroy()
|
||||
_TP = None
|
||||
|
||||
if _PP:
|
||||
_PP.destroy()
|
||||
_PP = None
|
||||
|
||||
if _DP:
|
||||
_DP.destroy()
|
||||
_DP = None
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.destroy_platform_model_parallel()
|
||||
|
||||
|
||||
def ascend_stateless_init_torch_distributed_process_group(
|
||||
host: str, port: int, rank: int, world_size: int,
|
||||
backend: str) -> ProcessGroup:
|
||||
"""
|
||||
A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state. The created ProcessGroup object can be used for
|
||||
some operations such as `allreduce`, because it does not depend on the
|
||||
global rank. However, some operations such as `broadcast` cannot be used
|
||||
because it depends on the global rank.
|
||||
|
||||
# TODO: ask for help from PyTorch team if we need the `broadcast` operation.
|
||||
|
||||
This function is useful when we are not sure about the total number of
|
||||
processes in the process group. For example, we may have process
|
||||
1, 2, ..., 8 who want to communicate, and process 9 might be the same
|
||||
process as process 1, or it might be a different process; process 10
|
||||
might be the same process as process 5, or it might be a different process.
|
||||
In this case, how can we reliably form a communication channel within
|
||||
process 9 and 10, without affecting the communication channel within
|
||||
process 1, 2, ..., 8?
|
||||
|
||||
One possible solution is to figure out if process 9 and 10 are the same
|
||||
as process 1 and 5 beforehand, and then form a communication channel
|
||||
based on the information, adjusting the ranks and world_size etc. However,
|
||||
figuring out the information is not always easy, and it will interfere
|
||||
with the main communication channel.
|
||||
|
||||
Our solution is to always form a communication channel with process 1, 2,
|
||||
..., 8, and then use this function to form another communication channel
|
||||
with process 9 and 10. This way, regardless of whether process 9 and 10
|
||||
are the same as process 1 and 5, the main communication channel is
|
||||
always formed with process 1, 2, ..., 8, and the additional communication
|
||||
channel is formed with process 9 and 10.
|
||||
"""
|
||||
init_method = f"tcp://{host}:{port}"
|
||||
backend = Backend(backend) # it is basically string
|
||||
timeout = _get_default_timeout(backend)
|
||||
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
store.set_timeout(timeout)
|
||||
|
||||
group_rank = rank
|
||||
group_size = world_size
|
||||
|
||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
prefix_store = PrefixStore(init_method, store)
|
||||
|
||||
pg: ProcessGroup = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
if backend == "gloo":
|
||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
||||
backend_class = ProcessGroupGloo(prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
timeout=timeout)
|
||||
backend_type = ProcessGroup.BackendType.GLOO
|
||||
device = torch.device("cpu")
|
||||
elif backend == "nccl":
|
||||
assert is_nccl_available()
|
||||
from torch.distributed.distributed_c10d import ProcessGroupNCCL
|
||||
|
||||
backend_options = ProcessGroupNCCL.Options()
|
||||
backend_options._timeout = timeout
|
||||
|
||||
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
|
||||
backend_options)
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
device = torch.device("cuda")
|
||||
elif current_platform.platform_has_backend_register():
|
||||
current_platform.platform_register_backend()
|
||||
return pg
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
|
||||
|
||||
pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
|
||||
return pg
|
||||
|
||||
|
||||
def parallel_config_get_dp_port(self) -> int:
|
||||
"""
|
||||
We might need to initialize process groups in multiple
|
||||
processes that is related to data parallelism,
|
||||
e.g. both in the worker and in the engine, which
|
||||
can live in different processes. To avoid port conflicts, we
|
||||
increment the port number each time we need to initialize a
|
||||
new process group related to data parallelism.
|
||||
"""
|
||||
answer = self.data_parallel_master_port
|
||||
self.data_parallel_master_port += 1
|
||||
import os
|
||||
|
||||
# NOTE: Get port from envs directly when using torchrun
|
||||
port = int(os.environ.get("MASTER_PORT", answer)) # type: ignore
|
||||
return port
|
||||
|
||||
|
||||
vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
|
||||
vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group
|
||||
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
|
||||
@@ -13,4 +13,5 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
#
|
||||
import vllm_ascend.patch.platform.patch_main.patch_distributed # noqa F401
|
||||
138
vllm_ascend/patch/platform/patch_main/patch_distributed.py
Normal file
138
vllm_ascend/patch/platform/patch_main/patch_distributed.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import torch
|
||||
import vllm
|
||||
import vllm.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
|
||||
_get_default_timeout,
|
||||
is_nccl_available)
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
|
||||
def ascend_destroy_model_parallel():
|
||||
"""Set the groups to none and destroy them."""
|
||||
from vllm.distributed.parallel_state import _DP, _PP, _TP
|
||||
if _TP:
|
||||
_TP.destroy()
|
||||
_TP = None
|
||||
|
||||
if _PP:
|
||||
_PP.destroy()
|
||||
_PP = None
|
||||
|
||||
if _DP:
|
||||
_DP.destroy()
|
||||
_DP = None
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.destroy_platform_model_parallel()
|
||||
|
||||
|
||||
def ascend_stateless_init_torch_distributed_process_group(
|
||||
host: str, port: int, rank: int, world_size: int,
|
||||
backend: str) -> ProcessGroup:
|
||||
"""
|
||||
A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state. The created ProcessGroup object can be used for
|
||||
some operations such as `allreduce`, because it does not depend on the
|
||||
global rank. However, some operations such as `broadcast` cannot be used
|
||||
because it depends on the global rank.
|
||||
|
||||
# TODO: ask for help from PyTorch team if we need the `broadcast` operation.
|
||||
|
||||
This function is useful when we are not sure about the total number of
|
||||
processes in the process group. For example, we may have process
|
||||
1, 2, ..., 8 who want to communicate, and process 9 might be the same
|
||||
process as process 1, or it might be a different process; process 10
|
||||
might be the same process as process 5, or it might be a different process.
|
||||
In this case, how can we reliably form a communication channel within
|
||||
process 9 and 10, without affecting the communication channel within
|
||||
process 1, 2, ..., 8?
|
||||
|
||||
One possible solution is to figure out if process 9 and 10 are the same
|
||||
as process 1 and 5 beforehand, and then form a communication channel
|
||||
based on the information, adjusting the ranks and world_size etc. However,
|
||||
figuring out the information is not always easy, and it will interfere
|
||||
with the main communication channel.
|
||||
|
||||
Our solution is to always form a communication channel with process 1, 2,
|
||||
..., 8, and then use this function to form another communication channel
|
||||
with process 9 and 10. This way, regardless of whether process 9 and 10
|
||||
are the same as process 1 and 5, the main communication channel is
|
||||
always formed with process 1, 2, ..., 8, and the additional communication
|
||||
channel is formed with process 9 and 10.
|
||||
"""
|
||||
init_method = f"tcp://{host}:{port}"
|
||||
backend = Backend(backend) # it is basically string
|
||||
timeout = _get_default_timeout(backend)
|
||||
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
store.set_timeout(timeout)
|
||||
|
||||
group_rank = rank
|
||||
group_size = world_size
|
||||
|
||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
prefix_store = PrefixStore(init_method, store)
|
||||
|
||||
pg: ProcessGroup = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
if backend == "gloo":
|
||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
||||
backend_class = ProcessGroupGloo(prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
timeout=timeout)
|
||||
backend_type = ProcessGroup.BackendType.GLOO
|
||||
device = torch.device("cpu")
|
||||
elif backend == "nccl":
|
||||
assert is_nccl_available()
|
||||
from torch.distributed.distributed_c10d import ProcessGroupNCCL
|
||||
|
||||
backend_options = ProcessGroupNCCL.Options()
|
||||
backend_options._timeout = timeout
|
||||
|
||||
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
|
||||
backend_options)
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
device = torch.device("cuda")
|
||||
elif current_platform.platform_has_backend_register():
|
||||
current_platform.platform_register_backend()
|
||||
return pg
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
|
||||
|
||||
pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
|
||||
return pg
|
||||
|
||||
|
||||
def parallel_config_get_dp_port(self) -> int:
|
||||
"""
|
||||
We might need to initialize process groups in multiple
|
||||
processes that is related to data parallelism,
|
||||
e.g. both in the worker and in the engine, which
|
||||
can live in different processes. To avoid port conflicts, we
|
||||
increment the port number each time we need to initialize a
|
||||
new process group related to data parallelism.
|
||||
"""
|
||||
answer = self.data_parallel_master_port
|
||||
self.data_parallel_master_port += 1
|
||||
import os
|
||||
|
||||
# NOTE: Get port from envs directly when using torchrun
|
||||
port = int(os.environ.get("MASTER_PORT", answer)) # type: ignore
|
||||
return port
|
||||
|
||||
|
||||
vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
|
||||
vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group
|
||||
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
|
||||
@@ -160,6 +160,8 @@ class NPUPlatform(Platform):
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla):
|
||||
if use_v1 and use_mla:
|
||||
return "vllm_ascend.attention.mla_v1.AscendMLABackend"
|
||||
if use_v1:
|
||||
return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
|
||||
if use_mla:
|
||||
@@ -191,3 +193,30 @@ class NPUPlatform(Platform):
|
||||
model configuration.
|
||||
"""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def destroy_platform_model_parallel(cls) -> None:
|
||||
from vllm_ascend.distributed.parallel_state import \
|
||||
destory_ascend_model_parallel
|
||||
destory_ascend_model_parallel()
|
||||
|
||||
@classmethod
|
||||
def platform_has_backend_register(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def platform_register_backend(cls, pg, prefix_store, group_rank,
|
||||
group_size, backend_options,
|
||||
timeout) -> None:
|
||||
from torch.distributed import ProcessGroup, is_hccl_available
|
||||
assert is_hccl_available()
|
||||
import torch_npu # noqa
|
||||
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
||||
backend_options = ProcessGroupHCCL.Options()
|
||||
backend_options._timeout = timeout
|
||||
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
|
||||
backend_options)
|
||||
device = torch.device("npu")
|
||||
backend_class._set_sequence_number_for_group()
|
||||
backend_type = ProcessGroup.BackendType.CUSTOM
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
|
||||
@@ -33,9 +33,7 @@ from vllm.model_executor.layers.quantization import \
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.model_executor.parameter import PerTensorScaleParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from .quantizer import AscendQuantizer
|
||||
@@ -171,12 +169,10 @@ class AscendLinearMethod(LinearMethodBase):
|
||||
output_size_per_partition,
|
||||
params_dtype)
|
||||
for weight_name, weight_param in weight_dict.items():
|
||||
layer.register_parameter(
|
||||
weight_name,
|
||||
ModelWeightParameter(data=weight_param,
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader))
|
||||
param = torch.nn.Parameter(weight_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter(weight_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
|
||||
for pertensor_name, pertensor_param in pertensor_dict.items():
|
||||
@@ -189,11 +185,10 @@ class AscendLinearMethod(LinearMethodBase):
|
||||
perchannel_dict = self.quant_method.get_perchannel_param(
|
||||
output_size_per_partition, params_dtype)
|
||||
for perchannel_name, perchannel_param in perchannel_dict.items():
|
||||
layer.register_parameter(
|
||||
perchannel_name,
|
||||
ChannelQuantScaleParameter(data=perchannel_param,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader))
|
||||
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"output_dim": 0})
|
||||
layer.register_parameter(perchannel_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
@@ -264,48 +259,6 @@ class AscendKVCacheMethod(BaseKVCacheMethod):
|
||||
seq_lens_tensor_cpu=seq_lens_tensor_cpu)
|
||||
|
||||
|
||||
def fused_moe_perchannel_weight_loader(param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str, shard_id: str,
|
||||
expert_id: int) -> None:
|
||||
|
||||
if shard_id not in ("w1", "w2", "w3"):
|
||||
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
||||
f"got {shard_id}.")
|
||||
|
||||
# Fetch the dim to shard the parameter/loaded weight
|
||||
# based on the shard id. This will be whatever
|
||||
# dimension intermediate_size_per_partition is used.
|
||||
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
|
||||
|
||||
expert_data = param.data[expert_id]
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# is_transposed: if the dim to shard the weight
|
||||
# should be flipped. Required by GPTQ, compressed-tensors
|
||||
# should be whatever dimension intermediate_size_per_partition is
|
||||
is_transposed = getattr(param, "is_transposed", False)
|
||||
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
||||
if is_transposed:
|
||||
shard_dim = int(not shard_dim)
|
||||
|
||||
if shard_id == "w2":
|
||||
expert_data.copy_(loaded_weight)
|
||||
elif shard_id in ("w1", "w3"):
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
|
||||
shard_size)
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
"""FusedMoE method for Ascend quantization.
|
||||
|
||||
@@ -341,9 +294,6 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
|
||||
# load `offset` weight in `fused_moe_perchannel_weight_loader`, the original weight load in vllm 0.7.3 could only load `scale` and `zero`
|
||||
extra_weight_attrs.update(
|
||||
{"weight_loader": fused_moe_perchannel_weight_loader})
|
||||
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
|
||||
num_experts, intermediate_size_per_partition, hidden_size,
|
||||
params_dtype)
|
||||
@@ -360,15 +310,19 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
is_prefill: bool = True,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
return self.quant_method.apply(layer, x, use_grouped_topk, top_k,
|
||||
router_logits, renormalize, topk_group,
|
||||
num_expert_group,
|
||||
num_expert_group, global_num_experts,
|
||||
expert_map, is_prefill,
|
||||
custom_routing_function, scoring_func,
|
||||
e_score_correction_bias)
|
||||
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
||||
#
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
from packaging.version import Version
|
||||
@@ -23,6 +25,8 @@ from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
|
||||
VLLM_ENABLE_GRAPH_MODE = os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0')
|
||||
|
||||
|
||||
def try_register_lib(lib_name: str, lib_info: str = ""):
|
||||
import importlib
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import vllm_ascend.worker.cache_engine # noqa
|
||||
69
vllm_ascend/worker/cache_engine.py
Normal file
69
vllm_ascend/worker/cache_engine.py
Normal file
@@ -0,0 +1,69 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/model_runner.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
|
||||
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
|
||||
|
||||
|
||||
def allocate_kv_cache(
|
||||
self,
|
||||
num_blocks: int,
|
||||
device: str,
|
||||
) -> List[Tuple]:
|
||||
"""Allocates KV cache on the specified device."""
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
|
||||
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
||||
kv_cache: List[Tuple] = []
|
||||
|
||||
# Align entries so they are 256 byte aligned for better performance
|
||||
# Primarily targets MLA as this typically only ends up having entries
|
||||
# be 128 byte aligned.
|
||||
alloc_shape = kv_cache_shape
|
||||
|
||||
for _ in range(self.num_attention_layers):
|
||||
# null block in CpuGpuBlockAllocator requires at least that
|
||||
# block to be zeroed-out.
|
||||
# We zero-out everything for simplicity.
|
||||
layer_kv_cache_nope = torch.zeros(
|
||||
alloc_shape[:-1] +
|
||||
(self.model_config.hf_text_config.kv_lora_rank, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=pin_memory,
|
||||
device=device)
|
||||
layer_kv_cache_pe = torch.zeros(
|
||||
alloc_shape[:-1] +
|
||||
(self.model_config.hf_text_config.qk_rope_head_dim, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=pin_memory,
|
||||
device=device)
|
||||
|
||||
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
|
||||
# when entry_shape is higher than 1D
|
||||
kv_cache.append((layer_kv_cache_nope, layer_kv_cache_pe))
|
||||
return kv_cache
|
||||
|
||||
|
||||
if VLLM_ENABLE_GRAPH_MODE == '1':
|
||||
CacheEngine._allocate_kv_cache = allocate_kv_cache
|
||||
@@ -18,18 +18,21 @@
|
||||
#
|
||||
|
||||
import dataclasses
|
||||
import itertools
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
|
||||
Type, TypeVar, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
@@ -53,7 +56,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists,
|
||||
is_pin_memory_available)
|
||||
is_pin_memory_available, supports_dynamo)
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
@@ -72,6 +75,7 @@ if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
TModelInputForNPU = TypeVar('TModelInputForNPU', bound="ModelInputForNPU")
|
||||
ENCODER_NUM = 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -526,6 +530,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
|
||||
seq_lens = []
|
||||
max_decode_seq_len = 0
|
||||
is_prompt = self.inter_data_list[0].is_prompt
|
||||
for inter_data in self.inter_data_list:
|
||||
seq_lens.extend(inter_data.seq_lens)
|
||||
if not inter_data.is_prompt:
|
||||
@@ -540,7 +545,26 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
for data in self.inter_data_list
|
||||
}
|
||||
|
||||
input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens),
|
||||
# Add graph_pad_size here
|
||||
if self.runner.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
graph_pad_size = self.runner.scheduler_config.max_num_seqs - len(
|
||||
seq_lens)
|
||||
else:
|
||||
graph_pad_size = -1
|
||||
|
||||
#print(f"before tensor input_tokens: {input_tokens}")
|
||||
#print(f"before tensor input_positions: {input_positions}")
|
||||
#print(f"before list seq_lens: {seq_lens}")
|
||||
input_tokens = flatten_2d_lists(input_tokens)
|
||||
input_positions = flatten_2d_lists(input_positions)
|
||||
if graph_pad_size != -1 and not is_prompt:
|
||||
input_tokens.extend(itertools.repeat(0, graph_pad_size))
|
||||
input_positions.extend( # type: ignore
|
||||
itertools.repeat(0, graph_pad_size))
|
||||
seq_lens.extend(itertools.repeat(1, graph_pad_size))
|
||||
query_lens.extend(itertools.repeat(1, graph_pad_size))
|
||||
input_tokens_tensor = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
if mrope_input_positions is not None:
|
||||
@@ -548,13 +572,16 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
else:
|
||||
input_positions_tensor = torch.tensor(
|
||||
flatten_2d_lists(input_positions),
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
input_positions_tensor = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
#print(f"after tensor input_tokens_tensor: {input_tokens_tensor}")
|
||||
#print(f"after tensor input_positions_tensor: {input_positions_tensor}")
|
||||
#print(f"after list seq_lens: {seq_lens}")
|
||||
|
||||
# Attention metadata.
|
||||
attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens)
|
||||
attn_metadata = self.attn_metadata_builder.build(
|
||||
seq_lens, query_lens, graph_pad_size)
|
||||
|
||||
# LoRA data.
|
||||
lora_requests = set()
|
||||
@@ -582,6 +609,13 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
]
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
if self.runner.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
torch._dynamo.mark_static(input_tokens_tensor)
|
||||
torch._dynamo.mark_static(input_positions_tensor)
|
||||
torch._dynamo.mark_static(attn_metadata.block_tables)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens_tensor,
|
||||
input_positions=input_positions_tensor,
|
||||
@@ -841,6 +875,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
|
||||
self.in_profile_run = False
|
||||
|
||||
self.graph_block_tables = np.zeros(
|
||||
(self.vllm_config.scheduler_config.max_num_seqs,
|
||||
(model_config.max_model_len + self.block_size - 1) //
|
||||
self.block_size),
|
||||
dtype=np.int32)
|
||||
|
||||
# Attention-free but stateful models like Mamba need a placeholder attn
|
||||
# backend, as the attention metadata is needed to manage internal state.
|
||||
# However we must bypass attention selection altogether for some models
|
||||
@@ -930,6 +970,26 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
|
||||
)
|
||||
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||
|
||||
# adapter torch compile with npu_backend
|
||||
if self.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
# 通信算子成图
|
||||
patch_for_hcom()
|
||||
# 设置npu的config,如果不设置config,可以使用默认的,那可以设置npu_backend="npu"
|
||||
config = torchair.CompilerConfig()
|
||||
config.experimental_config.frozen_parameter = True
|
||||
config.experimental_config.tiling_schedule_optimize = True
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
self.compile_model = torchair.inference.cache_compile(
|
||||
self.model.forward,
|
||||
dynamic=True,
|
||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
@@ -1219,10 +1279,43 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
self.attn_state.begin_forward(model_input)
|
||||
|
||||
assert model_input.attn_metadata is not None
|
||||
if self.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
torch._dynamo.mark_static(model_input.input_tokens)
|
||||
torch._dynamo.mark_static(model_input.input_positions)
|
||||
torch._dynamo.mark_static(model_input.attn_metadata.block_tables)
|
||||
torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping)
|
||||
torch._dynamo.mark_static(
|
||||
model_input.attn_metadata.query_start_loc)
|
||||
torch._dynamo.mark_static(model_input.attn_metadata.seq_start_loc)
|
||||
for kv in kv_caches:
|
||||
if isinstance(kv, tuple):
|
||||
torch._dynamo.mark_static(kv[0])
|
||||
torch._dynamo.mark_static(kv[1])
|
||||
|
||||
# TODO(andoorve): We can remove this once all
|
||||
# virtual engines share the same kv cache.
|
||||
virtual_engine = model_input.virtual_engine
|
||||
model_executable = self.model
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
previous_hidden_states = kwargs.get("previous_hidden_states")
|
||||
if prefill_meta is None and self.vllm_config.compilation_config.level > 0:
|
||||
model_executable = self.compile_model
|
||||
# Note: graph_batch_size value not same as GPU
|
||||
graph_batch_size = model_input.input_tokens.shape[ # type: ignore
|
||||
0] # type: ignore
|
||||
# Note: previous_hidden_states maybe None not same as GPU
|
||||
if previous_hidden_states is not None:
|
||||
previous_hidden_states = torch.cat([
|
||||
previous_hidden_states,
|
||||
torch.empty([
|
||||
graph_batch_size - previous_hidden_states.shape[0],
|
||||
*previous_hidden_states.shape[1:]
|
||||
],
|
||||
dtype=previous_hidden_states.dtype,
|
||||
device=previous_hidden_states.device)
|
||||
])
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# Receive KV cache in distributed KV cache transfer setting
|
||||
# In disagg prefill setting, it will also recv hidden states and bypass
|
||||
@@ -1248,8 +1341,11 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_inner_state else {}
|
||||
|
||||
previous_hidden_states = kwargs.get("previous_hidden_states")
|
||||
model_kwargs = {}
|
||||
if self.vllm_config.compilation_config.level ==\
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
model_kwargs = {"inputs_embeds": None}
|
||||
else:
|
||||
model_kwargs = {}
|
||||
if previous_hidden_states is not None:
|
||||
model_kwargs["previous_hidden_states"] = previous_hidden_states
|
||||
|
||||
@@ -1273,44 +1369,30 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
**seqlen_agnostic_kwargs,
|
||||
**model_kwargs)
|
||||
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.record()
|
||||
|
||||
# Sending KV cache in distributed KV cache transfer setting
|
||||
# NOTE: the send operation is non-blocking
|
||||
if self.need_send_kv(model_input, kv_caches):
|
||||
get_kv_transfer_group().send_kv_caches_and_hidden_states(
|
||||
# model_executable is used to know which layer the current
|
||||
# worker is working on, so that we can send KV for only those
|
||||
# layers.
|
||||
model_executable,
|
||||
model_input,
|
||||
kv_caches,
|
||||
hidden_or_intermediate_states,
|
||||
)
|
||||
|
||||
# Compute the logits in the last pipeline stage.
|
||||
if not get_pp_group().is_last_rank:
|
||||
if (self.is_driver_worker
|
||||
and hidden_or_intermediate_states is not None
|
||||
and isinstance(hidden_or_intermediate_states,
|
||||
IntermediateTensors)
|
||||
and self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
hidden_or_intermediate_states.tensors["model_forward_time"] = (
|
||||
torch.tensor(model_forward_time + orig_model_forward_time))
|
||||
return hidden_or_intermediate_states
|
||||
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
# Compute the logits in the last pipeline stage.
|
||||
if not get_pp_group().is_last_rank:
|
||||
if (self.is_driver_worker
|
||||
and hidden_or_intermediate_states is not None
|
||||
and isinstance(hidden_or_intermediate_states,
|
||||
IntermediateTensors)
|
||||
and self.observability_config is not None and
|
||||
self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
hidden_or_intermediate_states.tensors[
|
||||
"model_forward_time"] = (
|
||||
torch.tensor(model_forward_time +
|
||||
orig_model_forward_time))
|
||||
return hidden_or_intermediate_states
|
||||
# TODO: remove the synchronize here
|
||||
torch.npu.synchronize()
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
@@ -1348,6 +1430,9 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
hidden_states = hidden_or_intermediate_states.index_select(
|
||||
0, indices)
|
||||
output.prefill_hidden_states = hidden_or_intermediate_states
|
||||
elif self.vllm_config.compilation_config.level == \
|
||||
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
|
||||
hidden_states = hidden_or_intermediate_states[:len(indices)]
|
||||
else:
|
||||
hidden_states = hidden_or_intermediate_states
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -47,8 +48,7 @@ from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.attention.attention import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -104,6 +104,27 @@ class NPUModelRunner:
|
||||
raise NotImplementedError(
|
||||
"Non-Attention backend is not supported by V1 NPUModelRunner.")
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.head_size,
|
||||
self.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
use_mla=self.model_config.use_mla,
|
||||
)
|
||||
if self.attn_backend is None:
|
||||
error_msg = (
|
||||
f"Error with get_att_backend: {self.head_size=}, "
|
||||
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
|
||||
f"{self.model_config.is_attention_free=}, "
|
||||
f"{self.model_config.use_mla=}")
|
||||
logger.error(error_msg)
|
||||
raise NotImplementedError(
|
||||
"Non-Attention backend is not supported by V1 GPUModelRunner.")
|
||||
|
||||
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
||||
weakref.proxy(self))
|
||||
|
||||
# Multi-modal data support
|
||||
self.input_registry = INPUT_REGISTRY
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
@@ -191,6 +212,12 @@ class NPUModelRunner:
|
||||
pin_memory=True)
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
|
||||
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
||||
|
||||
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
@@ -200,6 +227,8 @@ class NPUModelRunner:
|
||||
self.input_positions_cpu = torch.arange(0,
|
||||
self.max_num_tokens,
|
||||
device="cpu")
|
||||
self.attn_mask = None
|
||||
self.attn_state = None
|
||||
|
||||
# NOTE: Pre-construct a mask matrix to improve the efficiency of
|
||||
# attention mask construction during inference.
|
||||
@@ -396,7 +425,11 @@ class NPUModelRunner:
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
# Copy the blocks from CPU to NPU.
|
||||
modified_batch = self.attn_metadata_builder.reorder_batch(
|
||||
self.input_batch, scheduler_output)
|
||||
if modified_batch:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.input_batch.block_table.commit(num_reqs)
|
||||
@@ -430,14 +463,13 @@ class NPUModelRunner:
|
||||
self.positions[:total_num_scheduled_tokens].copy_(
|
||||
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||
positions = self.positions[:total_num_scheduled_tokens]
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
|
||||
self.seq_lens_np[:num_reqs] = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||
num_scheduled_tokens)
|
||||
seq_lens = self.seq_lens_cpu[:num_reqs]
|
||||
|
||||
query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions_np // self.block_size)
|
||||
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
|
||||
@@ -446,8 +478,6 @@ class NPUModelRunner:
|
||||
np.add(block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:total_num_scheduled_tokens])
|
||||
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
||||
self.device, non_blocking=True)
|
||||
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
@@ -461,15 +491,14 @@ class NPUModelRunner:
|
||||
query_lens=num_scheduled_tokens,
|
||||
position=positions,
|
||||
attn_state=attn_state)
|
||||
self.attn_mask = attn_mask
|
||||
self.attn_state = attn_state # type: ignore
|
||||
|
||||
attn_metadata = AscendMetadata(
|
||||
seq_lens=query_lens,
|
||||
context_lens=seq_lens,
|
||||
slot_mapping=slot_mapping,
|
||||
block_tables=(
|
||||
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_prefix_len=None,
|
||||
)
|
||||
|
||||
# Prepare input_ids
|
||||
@@ -804,6 +833,9 @@ class NPUModelRunner:
|
||||
# different GPUs, and `kv_cache_config.num_blocks` is set to
|
||||
# the min of all `num_blocks`. Verify it here.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
||||
# encounter OOM issue
|
||||
num_blocks = num_blocks // 4
|
||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
|
||||
@@ -44,6 +44,7 @@ from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import try_register_lib, vllm_version_is
|
||||
from vllm_ascend.worker.model_runner import NPUModelRunner
|
||||
@@ -313,8 +314,14 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size):
|
||||
num_layers = len(self.cache_engine[ve].gpu_cache)
|
||||
for i in range(num_layers):
|
||||
torch_npu.npu_format_cast(self.cache_engine[ve].gpu_cache[i],
|
||||
2)
|
||||
if torch.is_tensor(self.cache_engine[ve].gpu_cache[i]):
|
||||
torch_npu.npu_format_cast(
|
||||
self.cache_engine[ve].gpu_cache[i], 2)
|
||||
else:
|
||||
torch_npu.npu_format_cast(
|
||||
self.cache_engine[ve].gpu_cache[i][0], 2)
|
||||
torch_npu.npu_format_cast(
|
||||
self.cache_engine[ve].gpu_cache[i][1], 2)
|
||||
self.gpu_cache = [
|
||||
self.cache_engine[ve].gpu_cache
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size)
|
||||
@@ -495,6 +502,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
backend: str = "hccl") -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
parallel_config = self.parallel_config
|
||||
additional_config = self.vllm_config.additional_config
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank,
|
||||
@@ -502,6 +510,14 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
expert_tensor_parallel_size = 1
|
||||
if additional_config is not None and hasattr(
|
||||
additional_config, "expert_tensor_parallel_size"):
|
||||
expert_tensor_parallel_size = getattr(
|
||||
additional_config, "expert_tensor_parallel_size")
|
||||
init_ascend_model_parallel(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
expert_tensor_parallel_size)
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import try_register_lib, vllm_version_is
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
@@ -209,6 +210,8 @@ class NPUWorker(WorkerBase):
|
||||
|
||||
def _init_worker_distributed_environment(self) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
additional_config = self.vllm_config.additional_config
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
set_custom_all_reduce(
|
||||
not self.parallel_config.disable_custom_all_reduce)
|
||||
init_distributed_environment(self.parallel_config.world_size,
|
||||
@@ -217,6 +220,13 @@ class NPUWorker(WorkerBase):
|
||||
ensure_model_parallel_initialized(
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
expert_tensor_parallel_size = 1
|
||||
if additional_config is not None and "expert_tensor_parallel_size" in additional_config:
|
||||
expert_tensor_parallel_size = int(
|
||||
additional_config["expert_tensor_parallel_size"])
|
||||
init_ascend_model_parallel(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
expert_tensor_parallel_size)
|
||||
ensure_kv_transfer_initialized(self.vllm_config)
|
||||
|
||||
def _init_profiler(self):
|
||||
|
||||
Reference in New Issue
Block a user