port deepseekv2 and mtp to main branch (#429)

### What this PR does / why we need it?
This PR ports all the DeepSeek graph-mode code and the MTP (multi-token
prediction) code from v0.7.3 to the main branch.
---------

Signed-off-by: SidaoY <1024863041@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: mengwei805 <mengwei25@huawei.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: q00832892 <qiaoyang19@huawei.com>
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Co-authored-by: SidaoY <1024863041@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: libaokui <libaokui@huawei.com>
This commit is contained in:
Pleaplusone
2025-04-19 17:38:18 +08:00
committed by GitHub
parent 086423dc35
commit 1a1f9a6d89
33 changed files with 3361 additions and 315 deletions

View File

@@ -0,0 +1,128 @@
"""
This file demonstrates the example usage of disaggregated prefilling
We will launch 2 vllm instances (NPU 0,1 for prefill and NPU 2,3 for decode),
and then transfer the KV cache between them.
"""
import multiprocessing as mp
import os
import time
from multiprocessing import Event, Process
def clean_up():
    """Tear down vLLM's distributed state and release NPU memory.

    Invoked by both the prefill and decode workers after generation so a
    subsequent run starts from a clean slate.
    """
    import gc

    import torch
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)

    # Order matters: drop the model-parallel groups before the global
    # distributed environment, then reclaim Python and NPU memory.
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()
def run_prefill(prefill_done, process_close):
    """Run the prefill-side vLLM instance on NPUs 0,1.

    Generates exactly one token per prompt (max_tokens=1) so the KV cache
    is produced, signals ``prefill_done``, then idles until the parent
    process sets ``process_close``.
    """
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    prompts = [
        "Hello, how are you today?", "Hi, what is your name?",
        "Tell me a very long story.", "what is your favourite book?"
    ]
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    # kv_role=kv_producer: this instance writes the KV cache that the
    # decode instance (kv_consumer) reads via the AscendHcclConnector.
    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
    )

    # NOTE(review): the 0.8 utilization figure was copied from a GPU
    # (A6000) example — adjust it to fit the memory of your NPU.
    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
              kv_transfer_config=ktc,
              max_model_len=2000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2)

    llm.generate(prompts, sampling_params)
    print("Prefill node is finished.")
    prefill_done.set()

    # To keep the prefill node running in case the decode node is not done;
    # otherwise, the script might exit prematurely, causing incomplete decoding.
    try:
        while not process_close.is_set():
            time.sleep(1)
    except KeyboardInterrupt:
        print("Script stopped by user.")
    finally:
        print("Cleanup prefill resources")
        del llm
        clean_up()
def run_decode(prefill_done):
    """Run the decode-side vLLM instance on NPUs 2,3.

    Waits for the prefill instance to signal completion, then decodes the
    same prompts using the transferred KV cache and prints the outputs.
    """
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    prompts = [
        "Hello, how are you today?", "Hi, what is your name?",
        "Tell me a very long story.", "what is your favourite book?"
    ]
    sampling_params = SamplingParams(temperature=0, top_p=0.95)

    # kv_role=kv_consumer: this instance reads the KV cache produced by
    # the prefill instance through the AscendHcclConnector.
    ktc = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
    )

    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
              kv_transfer_config=ktc,
              max_model_len=2000,
              gpu_memory_utilization=0.8,
              tensor_parallel_size=2)

    # Block until the producer signals that prefill has finished.
    print("Waiting for prefill node to finish...")
    prefill_done.wait()

    # At this point when the prefill_done is set, the kv-cache should have been
    # transferred to this decode node, so we can start decoding.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    del llm
    clean_up()
if __name__ == "__main__":
    # BUG FIX: the return value of mp.get_context('spawn') was previously
    # discarded, so Event/Process silently fell back to the default start
    # method. Bind them to the spawn context explicitly; spawn avoids
    # forking a process that may have initialized device runtime state.
    ctx = mp.get_context('spawn')
    prefill_done = ctx.Event()
    process_close = ctx.Event()
    prefill_process = ctx.Process(target=run_prefill,
                                  args=(
                                      prefill_done,
                                      process_close,
                                  ))
    decode_process = ctx.Process(target=run_decode, args=(prefill_done, ))

    # Start prefill node
    prefill_process.start()
    # Start decode node
    decode_process.start()

    # Wait for decode to finish; the prefill node keeps running until then.
    decode_process.join()

    # Signal the prefill loop to exit, then reap the process.
    process_close.set()
    prefill_process.join()
    prefill_process.terminate()
    print("All process done!")

View File

@@ -0,0 +1,86 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py
# SPDX-License-Identifier: Apache-2.0
# usage:
# python examples/offline_inference_data_parallel.py
# we need to have a launcher to create multiple data parallel
# ranks. And each rank will create a vLLM instance to process its own prompts.
import gc
import os
# Read once at import time: "1" enables torchair graph mode for this example.
# (Fixes the previous misspelling "VLLM_ENABLE_GRAPGH_MODE".)
VLLM_ENABLE_GRAPH_MODE = os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1"


def main():
    """Run one data-parallel rank of the DeepSeek-V2-Lite example.

    Expects torchrun-style environment variables (RANK, LOCAL_RANK,
    WORLD_SIZE, MASTER_ADDR, MASTER_PORT); each rank creates its own vLLM
    instance and processes its slice of the prompt list.
    """
    dp_rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    dp_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ['MASTER_ADDR']
    master_port = os.environ['MASTER_PORT']
    tp_size = 4
    etp_size = 2

    os.environ["VLLM_DP_RANK"] = str(dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = master_addr
    os.environ["VLLM_DP_MASTER_PORT"] = master_port
    # Each DP rank owns a disjoint slice of tp_size NPUs.
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = ",".join(
        str(i)
        for i in range(local_rank * tp_size, (local_rank + 1) * tp_size))

    import torch
    import torch_npu  # noqa
    from vllm import LLM, SamplingParams
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment, destroy_model_parallel)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 4

    # Shard the prompts evenly across DP ranks; a rank must never run with
    # zero prompts, so fall back to a placeholder. (Typo fix: was
    # "promts_per_rank".)
    prompts_per_rank = len(prompts) // dp_size
    start = dp_rank * prompts_per_rank
    end = start + prompts_per_rank
    prompts = prompts[start:end]
    if len(prompts) == 0:
        prompts = ["Placeholder"]
    print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")
    num_seqs = len(prompts)

    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=4,
                                     min_tokens=4)
    # Create an LLM.
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite-Chat",
        tensor_parallel_size=tp_size,
        trust_remote_code=True,
        expert_tensor_parallel_size=etp_size,
        max_model_len=4096,
        max_num_seqs=num_seqs,
        compilation_config=1 if VLLM_ENABLE_GRAPH_MODE else 0,
    )

    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"DP rank {dp_rank}, Prompt: {prompt!r}, "
              f"Generated text: {generated_text!r}")

    # Explicit teardown so distributed groups and NPU memory are released
    # before the process exits.
    del llm
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,21 @@
# Network interface bindings for the HCCL / Gloo collectives.
# NOTE(review): ${local_ip} and ${ifname} are not defined anywhere in this
# script — they must be exported by the caller before invoking it.
export HCCL_IF_IP=${local_ip}
export GLOO_SOCKET_IFNAME=${ifname}
export TP_SOCKET_IFNAME=${ifname}
export HCCL_SOCKET_IFNAME=${ifname}

# dp_size = node_size * dp_per_node
node_size=1
node_rank=0
dp_per_node=2
master_addr=127.0.0.1
master_port=12345

# Clear graph caches and debug logs from previous runs.
rm -rf ./.torchair_cache/
rm -rf ./dynamo_*
rm -rf /root/ascend/log/debug/plog/*

# Eager mode by default; set to 1 to enable torchair graph mode / MC2.
export VLLM_ENABLE_GRAPH_MODE=0
export VLLM_ENABLE_MC2=0

# One process per DP rank; data_parallel.py reads RANK/LOCAL_RANK/
# WORLD_SIZE/MASTER_ADDR/MASTER_PORT set by torchrun.
torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \
    --node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \
    data_parallel.py

View File

@@ -0,0 +1,49 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from vllm import LLM, SamplingParams
# Select the V1 engine and spawn-based worker processes; both must be set
# before the vLLM engine is created below.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

if __name__ == "__main__":
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object: greedy decoding, up to 100 new tokens.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
    # Create an LLM.
    # NOTE(review): hard-coded local weights path — point this at your own
    # checkpoint directory (or a model id) before running.
    llm = LLM(model="/data/weights/deepseek-ai/deepseekv3-lite-base-latest",
              tensor_parallel_size=2,
              enforce_eager=True,
              trust_remote_code=True,
              max_model_len=1024)
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@@ -26,8 +26,6 @@ import torch
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.config import TaskOption
from vllm.distributed.parallel_state import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
@@ -35,6 +33,15 @@ from vllm.utils import is_list_of
from tests.model_utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs)
# TODO: remove this part after the patch merged into vllm, if
# we not explicitly patch here, some of them might be effectiveless
# in pytest scenario
from vllm_ascend.utils import adapt_patch # noqa E402
adapt_patch(True)
from vllm.distributed.parallel_state import ( # noqa E402
destroy_distributed_environment, destroy_model_parallel)
_M = TypeVar("_M")

View File

@@ -31,13 +31,14 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType,
MLAAttentionImpl)
from vllm.attention.backends.utils import (CommonAttentionState,
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
CommonMetadataBuilder,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
from vllm_ascend.worker.model_runner import (
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
@@ -222,7 +223,7 @@ class AscendMLAAttentionBackend(AscendAttentionBackend):
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (1, num_blocks, block_size, num_kv_heads * head_size)
return (num_blocks, block_size, num_kv_heads, head_size)
@dataclass
@@ -552,10 +553,33 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
inter_data.block_tables,
)
def _get_graph_runner_block_tables(
        self, num_seqs: int,
        block_tables: List[List[int]]) -> torch.Tensor:
    """Pack per-sequence block tables into the runner's fixed-size numpy
    buffer used by the torchair graph and move it to the device.

    Args:
        num_seqs: number of sequences in the (padded) batch.
        block_tables: physical block ids per sequence.
    Returns:
        Device tensor of shape [max batch size, max blocks per seq].
    """
    # The shape of graph_block_tables is
    # [max batch size, max context len // block size].
    max_batch_size, max_blocks = self.runner.graph_block_tables.shape
    assert max_batch_size >= num_seqs

    # NOTE(review): the buffer is reused across calls and never cleared
    # here, so entries beyond what this loop writes may still hold block
    # ids from a previous batch — confirm padded rows are never read by
    # the kernel.
    graph_block_tables = self.runner.graph_block_tables  # [:num_seqs]
    for i, block_table in enumerate(block_tables):
        if block_table:
            num_blocks = len(block_table)
            if num_blocks <= max_blocks:
                graph_block_tables[i, :num_blocks] = block_table
            else:
                # Tables longer than the graph's maximum context length
                # are silently truncated.
                graph_block_tables[
                    i, :max_blocks] = block_table[:max_blocks]

    return torch.from_numpy(graph_block_tables).to(
        device=self.runner.device, non_blocking=True)
def build(
self,
seq_lens: List[int],
query_lens: List[int],
graph_pad_size: int,
):
"""Build attention metadata with on-device tensors.
@@ -568,6 +592,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_torchair_graph = graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
@@ -582,12 +607,36 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
self.attn_mask = None
num_decode_tokens = self.num_decode_tokens
if self.num_prefills == 0 and use_torchair_graph:
num_seqs = len(seq_lens)
self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
self.block_tables.extend([[]] * graph_pad_size)
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int32,
device=device,
)
if self.num_prefills > 0:
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
max_prefill_seq_len,
self.input_builder.runner.model_config.dtype,
self.input_builder.runner.device)
else:
self.attn_mask = None
num_decode_tokens = self.num_decode_tokens
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int32,
device=device,
)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
assert device is not None
@@ -855,14 +904,100 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
self.q_proj = extra_impl_args['q_proj']
self.kv_b_proj = extra_impl_args['kv_b_proj']
self.o_proj = extra_impl_args['o_proj']
self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa',
None)
self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None)
self.k_pe_cache = None
self.k_nope_cache = None
self.w_kc = None
self.w_vc = None
def exec_kv(
    self,
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    kv_cache: Tuple,
    slots: torch.Tensor,
):
    """Project hidden states to the latent KV space and write both halves
    into the paged KV cache using the fused NPU kernel.

    The fused op applies RMSNorm (via kv_a_layernorm's weight/epsilon) and
    rotary embedding, and scatters results into kv_cache at `slots`.

    Returns:
        (k_pe, k_nope) as produced by npu_kv_rmsnorm_rope_cache.
    """
    B = hidden_states.shape[0]
    N = self.num_kv_heads
    S = 1
    kv = self.kv_a_proj_with_mqa(hidden_states)[0]
    # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
    kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
    # NOTE(review): kv_cache[1] receives the rope half and kv_cache[0] the
    # nope/latent half, judging by the argument order — confirm against the
    # npu_kv_rmsnorm_rope_cache documentation.
    k_pe, k_nope = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
        kv,
        self.kv_a_layernorm.weight,
        cos,
        sin,
        slots.to(torch.int64),
        kv_cache[1],
        kv_cache[0],
        epsilon=self.kv_a_layernorm.variance_epsilon,
        cache_mode="PA",
    )
    return k_pe, k_nope
def apply_rotary_emb(
    self,
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """Apply rotary positional embeddings to ``x``.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Neox splits the head dim into two contiguous
            halves; GPT-J interleaves even/odd channels.
    Returns:
        Rotated tensor with the same shape as ``x``.
    """
    # Broadcast cos/sin over the heads dimension and match x's dtype.
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        first, second = torch.chunk(x, 2, dim=-1)
    else:
        first = x[..., ::2]
        second = x[..., 1::2]
    rotated_first = first * cos - second * sin
    rotated_second = second * cos + first * sin
    if is_neox_style:
        return torch.cat((rotated_first, rotated_second), dim=-1)
    # GPT-J: re-interleave the two halves back into even/odd channels.
    return torch.stack((rotated_first, rotated_second), dim=-1).flatten(-2)
def rope_single(
    self,
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary embedding to a single-position tensor via the fused
    NPU kernel.

    Args:
        x: [B, N, D] — one position per sequence.
        cos/sin: rotary tables; exact expected layout is dictated by
            npu_interleave_rope (assumed broadcastable to [B, N, 1, D] —
            TODO confirm).
    Returns:
        Tensor reshaped back to [B, N, D].
    """
    B, N, D = x.shape
    S = 1
    # The fused kernel expects a sequence axis: [B, N, S, D] with S == 1.
    x = x.view(B, N, S, D)
    x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
    return x.view(B, N, D)
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Split the fused kv_b_proj weight into the two absorbed MLA
    matrices, caching them on first call.

    Produces:
        w_kc: [num_heads, qk_nope_head_dim, kv_lora_rank]
        w_vc: [num_heads, kv_lora_rank, v_head_dim]
    """
    # Already materialized — nothing to do.
    if self.w_kc is not None and self.w_vc is not None:
        return
    fused = self.kv_b_proj.weight.reshape(
        self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
        self.kv_lora_rank)
    nope_part, v_part = fused.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=1)
    self.w_kc = nope_part.contiguous()
    self.w_vc = v_part.transpose(1, 2).contiguous()
def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor,
kv_c_normed: torch.Tensor,
hidden_states_or_kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AscendMetadata,
@@ -873,7 +1008,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
Args:
hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size]
num_tokens = batch_size * seq_len
kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
k_pe: shape = [num_tokens, num_kv_heads * head_size]
kv_cache: shape = [1, num_blocks, block_size,
num_kv_heads * head_size]
@@ -889,71 +1024,86 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
"are not implemented for "
"PallasAttentionBackendImpl")
if attn_metadata is None:
# for profile run
return hidden_states_or_q_c
num_tokens = hidden_states_or_q_c.shape[0]
q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
self.qk_head_dim)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
if k_pe is None and attn_metadata.decode_metadata:
seq_len = self.rotary_emb.max_position_embeddings
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
cos = cos[attn_metadata.input_positions]
sin = sin[attn_metadata.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(num_tokens, -1)
k_pe = k_pe.reshape(num_tokens, -1)
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
k_pe)
q_pe = q_pe.view(ori_q_pe_shape)
k_pe = k_pe.view(ori_k_pe_shape)
q_pe = self.rope_single(q_pe, cos, sin)
k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
kv_cache, attn_metadata.slot_mapping)
else:
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
k_pe)
if self.w_kc is None or self.w_vc is None:
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
self.kv_lora_rank)
self.w_kc = kv_b_proj_weight[:, :self.
qk_nope_head_dim, :].contiguous()
self.w_vc = kv_b_proj_weight[:,
self.qk_nope_head_dim:, :].transpose(
1, 2).contiguous()
if k_pe is None:
# NOTE: k_pe is None when graph mode enabled
kv_c, k_pe = self.kv_a_proj_with_mqa(
hidden_states_or_kv_c_normed)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
else:
kv_c_normed = hidden_states_or_kv_c_normed
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
# NOTE: When scaling not specified
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(num_tokens, -1)
k_pe = k_pe.reshape(num_tokens, -1)
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
q_pe, k_pe)
q_pe = q_pe.view(ori_q_pe_shape)
k_pe = k_pe.view(ori_k_pe_shape)
else:
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
q_pe, k_pe)
if attn_metadata.num_prefills > 0:
kv_heads_num = self.num_heads
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens, kv_heads_num,
-1)
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
self.num_heads, -1)
k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
dim=-1)
k_cache = torch.cat(
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
dim=2)
k_pe = k_pe.expand(-1, self.num_heads, -1)
key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe],
dim=2)
else:
kv_heads_num = self.num_kv_heads
q_nope_t = torch.transpose(q_nope, 0, 1)
q_nope_out = torch.bmm(q_nope_t, self.w_kc)
q_nope = torch.transpose(q_nope_out, 0, 1)
k_cache = torch.cat(
[kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
dim=2)
query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
self.num_heads, -1)
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
num_blocks, block_size, _ = key_cache.shape
key_cache = key_cache.view(
num_blocks, block_size, self.num_kv_heads,
self.qk_rope_head_dim + self.kv_lora_rank)
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache_siso(key=k_cache,
key_cache=key_cache,
slot_indices=slots)
# TODO: Replace the env with more flexible expressions
if VLLM_ENABLE_GRAPH_MODE == '1':
if len(kv_cache) > 0 and kv_cache[0].numel(
) > 0 and attn_metadata.num_prefills > 0:
slots = attn_metadata.slot_mapping
# NOTE: Seperate the kv cache in advance to avoid OOM or other issues
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
num_tokens, self.num_kv_heads, -1),
value=k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slots)
else:
if kv_cache.numel() > 0:
key = torch.cat([
kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe
],
dim=2)
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache_siso(key=key,
key_cache=kv_cache,
slot_indices=slots)
if attn_metadata.num_prefills > 0:
attn_output = torch.empty(num_tokens,
@@ -964,12 +1114,15 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
if (attn_metadata.block_tables is None
or attn_metadata.block_tables.numel() == 0):
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
assert attn_metadata.prefill_metadata is not None
assert attn_metadata.prefill_metadata.seq_lens is not None
mask = attn_metadata.attn_mask
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.prefill_metadata.seq_lens).astype(
np.int32))
k_pe = k_pe.repeat(1, self.num_heads, 1)
key = torch.cat(
[k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
torch_npu._npu_flash_attention(
query=query,
key=key,
@@ -987,29 +1140,55 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
)
elif attn_metadata.decode_metadata:
assert kv_cache is not None
# if torch.empty is used here, the preemptive scheduling case of
# test_mtp_correctness.py will fail to run.
attn_output = torch.randn(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=query.dtype,
device=query.device)
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.decode_metadata.seq_lens).astype(
np.int32))
block_tables = attn_metadata.decode_metadata.block_tables
torch_npu._npu_paged_attention_mla(
query=query,
key_cache=key_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=self.seq_lens_tensor_cpu,
mla_vheadsize=self.kv_lora_rank,
out=attn_output)
attn_output_t = torch.transpose(attn_output, 0, 1)
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
attn_output = torch.transpose(attn_output_t, 0, 1)
if VLLM_ENABLE_GRAPH_MODE == '1':
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=self.num_heads,
num_key_value_heads=1,
input_layout="BNSD",
atten_mask=attn_metadata.attn_mask,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=attn_metadata.block_tables,
block_size=kv_cache[0].shape[1],
actual_seq_lengths_kv=attn_metadata.seq_lens,
)
attn_output = attn_output.view(num_tokens, -1,
self.kv_lora_rank).transpose(
0, 1)
attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1)
else:
# if torch.empty is used here, the preemptive scheduling case of
# test_mtp_correctness.py will fail to run.
attn_output = torch.randn(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=query.dtype,
device=query.device)
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.decode_metadata.seq_lens).astype(
np.int32))
block_tables = attn_metadata.decode_metadata.block_tables
torch_npu._npu_paged_attention_mla(
query=query,
key_cache=kv_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=self.seq_lens_tensor_cpu,
mla_vheadsize=self.kv_lora_rank,
out=attn_output)
attn_output_t = torch.transpose(attn_output, 0, 1)
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
attn_output = torch.transpose(attn_output_t, 0, 1)
output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))

View File

@@ -24,6 +24,10 @@ import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm_ascend.ops.attention import vanilla_chunked_prefill
class AscendAttentionBackend(AttentionBackend):
@@ -44,6 +48,10 @@ class AscendAttentionBackend(AttentionBackend):
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
return AscendAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -94,11 +102,11 @@ class AscendAttentionState(Enum):
class AscendMetadata:
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
block_tables: Optional[torch.Tensor]
block_tables: torch.Tensor
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]] = None
context_lens: Optional[List[int]] = None
query_lens: torch.Tensor
seq_lens: torch.Tensor
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
@@ -117,6 +125,36 @@ class AscendMetadata:
attn_mask: Optional[torch.Tensor] = None
class AscendAttentionMetadataBuilder:
    """Builds AscendMetadata for a scheduling step from the runner's
    per-step buffers."""

    def __init__(self, runner):
        # The NPU model runner owning the input batch and the CPU/device
        # buffers consumed by build().
        self.runner = runner

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # This backend needs no decode/prefill reordering.
        return False

    def build(self, num_reqs, num_actual_tokens, max_query_len,
              common_prefix_len):
        """Assemble attention metadata for the current step.

        Args:
            num_reqs: number of scheduled requests.
            num_actual_tokens: total scheduled tokens (no padding).
            max_query_len: longest query in the batch.
            common_prefix_len: accepted for interface parity; unused here.
        """
        block_table = (
            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
        query_lens = self.runner.query_lens
        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
        # NOTE(review): non_blocking H2D copy assumes the source CPU buffer
        # is not overwritten before the copy completes — confirm.
        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
            self.runner.device, non_blocking=True)
        attn_mask = self.runner.attn_mask
        attn_state = self.runner.attn_state
        attn_metadata = AscendMetadata(block_tables=block_table,
                                       query_lens=query_lens,
                                       seq_lens=seq_lens,
                                       max_query_len=max_query_len,
                                       slot_mapping=slot_mapping,
                                       attn_mask=attn_mask,
                                       attn_state=attn_state)
        return attn_metadata
class AscendAttentionBackendImpl(AttentionImpl):
def __init__(
@@ -229,29 +267,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
block_tables = attn_metadata.block_tables
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=attn_metadata.context_lens,
out=output)
torch_npu._npu_paged_attention(query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
# Normal V1 situation.
else:
# use paged attention
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.seq_lens,
context_lens=attn_metadata.context_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
# use chunked prefill for head size 192 scenario, like deepseek
# paged_attention_splitfuse maybe crash at such scenario
# TODO: vanilla path will be removed after the kernel support
# head_size 192 scenario
if self.head_size == 192:
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
max_seqlen_q = torch.max(attn_metadata.query_lens)
max_seqlen_k = torch.max(attn_metadata.seq_lens)
vanilla_chunked_prefill(output, query, self.key_cache,
self.value_cache,
attn_metadata.block_tables,
cu_seqlen_q, cu_seqlen_k, max_seqlen_q,
max_seqlen_k, self.scale, None, True)
else:
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
return output.view(num_tokens, self.hidden_size)

View File

@@ -0,0 +1,561 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar
import torch
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.ops.cache import concat_and_cache_mla
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
logger = init_logger(__name__)
class AscendMLABackend(AttentionBackend):
    """Backend descriptor wiring MLA (multi-head latent attention) on
    Ascend NPUs to its metadata, builder, and implementation classes."""

    # Kernels may write directly into a caller-provided output buffer.
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "VLLM_ASCEND_MLA"

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return AscendMLAMetadata

    @staticmethod
    def get_builder_cls():
        return AscendMLAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                           head_size: int) -> tuple[int, ...]:
        # One cache slot per (block, in-block position, kv head).
        return (num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_impl_cls() -> Type["MLAAttentionImpl"]:
        return AscendMLAImpl
@dataclass
class AscendMLAPrefillMetadata:
    """ Prefill Specific Metadata for Ascend"""
    # Attention mask applied to the prefill tokens.
    attn_mask: torch.Tensor
    # Scheduled query tokens per request.
    query_lens: list[int]
    # Already-computed tokens per request (see the context_len diagram in
    # AscendMLAMetadata).
    context_lens: torch.Tensor
    # Positions of the new tokens, used for rotary embedding.
    input_positions: torch.Tensor
    # Physical block ids per request.
    block_table: torch.Tensor
    max_query_len: int
    max_context_len: int
@dataclass
class AscendMLADecodeMetadata:
    """Decode-specific metadata for Ascend MLA."""
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor
    # Physical block ids per decoding request.
    block_table: torch.Tensor
    # Total sequence length (context + new token) per request.
    seq_lens: torch.Tensor
@dataclass
class AscendMLAMetadata:
    """Metadata for MLACommon.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    # Per-token indices into the paged KV cache.
    slot_mapping: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    # The dimension of the attention heads
    head_dim: Optional[int] = None
    attn_mask: torch.Tensor = None
    # chunked prefill by default if no attn_states passed
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    # Split views of this batch; either may be None when the batch has no
    # requests of that kind.
    decode: Optional[AscendMLADecodeMetadata] = None
    prefill: Optional[AscendMLAPrefillMetadata] = None

    def __post_init__(self):
        # Head-dim validation is intentionally disabled for now; kept as a
        # placeholder until get_supported_head_sizes() is implemented.
        pass
        # supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
        # if self.head_dim is not None and self.head_dim \
        #         not in supported_head_sizes:
        #     raise ValueError(
        #         f"Only {supported_head_sizes} are supported for head_dim,",
        #         f"received {self.head_dim}.")
M = TypeVar("M", bound=AscendMLAMetadata)
class AscendMLAMetadataBuilder:
    """Builds :class:`AscendMLAMetadata` for a mixed prefill/decode batch.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(self,
                 runner: "NPUModelRunner",
                 metadata_cls: Optional[AscendMLAMetadata] = None):
        # Allow callers to supply an AscendMLAMetadata subclass; fall back to
        # the base class so `build` always has something to instantiate.
        self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
            if metadata_cls is not None else AscendMLAMetadata  # type: ignore
        self.runner = runner
        scheduler_config = runner.scheduler_config
        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Reorder the batch so decode requests precede prefill requests.

        Uses the minimum number of swaps; returns True if the batch was
        modified. Also caches the decode/prefill counts for the next
        ``build`` call.
        """
        # We now want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the back using the least
        # amount of swaps possible. (NOTE for now we loosely use "decode" to
        # mean requests where attention is likely memory-bound and "prefill"
        # to mean requests where attention is likely compute-bound,
        # TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # for now treat 1 scheduled token as "decode" even if its not,
            # we should update this to something like < 8 in the future but
            # currently the TritonMLA._forward_decode only supports
            # num_tokens = 1
            if num_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new request are generally appended to the
        # persistent batch so already should be at the back)
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back"
        # i.e. past where the last decode should be in the reordered batch
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based on
        # the above loop
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        first_prefill = 0
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i, we can swap it
            # with the prefill closest to the front of the batch
            if decodes[num_decodes - i] >= num_decodes:
                input_batch.swap_states(prefills[first_prefill],
                                        decodes[num_decodes - i])
                first_prefill += 1
                modified_batch = True
            else:
                break

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def build(self,
              num_reqs: int,
              num_actual_tokens: int,
              max_query_len: int,
              common_prefix_len: Optional[int] = None) -> AscendMLAMetadata:
        """Assemble per-step MLA metadata after ``reorder_batch`` has run.

        ``num_reqs`` must equal the cached decode + prefill request counts.
        ``common_prefix_len`` is currently unused.
        """
        assert self._num_decodes + self._num_prefills == num_reqs

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.runner.device
        block_table = (
            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
            device, non_blocking=True).long()
        input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
            device, non_blocking=True).long()

        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
        query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
                                                                                           num_reqs]
        seq_lens = seq_lens_cpu
        # NOTE: the incoming `max_query_len` argument is recomputed from the
        # actual per-request query lengths (`.item()` triggers a CPU sync).
        max_query_len = query_lens.max().item()
        max_context_len = seq_lens.max().item()

        prefill_metadata = None
        if self._num_prefills > 0:
            reqs_start = self._num_decodes  # prefill_start
            tokens_start = self._num_decode_tokens
            # `query_lens`, `seq_lens` and `block_table` are indexed per
            # *request*, so they must be split at `reqs_start`; only
            # `input_positions` is per *token*. (Previously the per-request
            # tensors were sliced with the token count, which is only correct
            # while every decode request contributes exactly one token.)
            prefill_metadata = AscendMLAPrefillMetadata(
                attn_mask=self.runner.attn_mask,
                query_lens=query_lens[reqs_start:],
                context_lens=seq_lens[reqs_start:],
                input_positions=input_positions[tokens_start:],
                block_table=block_table[reqs_start:, ...],
                max_query_len=max_query_len,
                max_context_len=max_context_len,
            )

        decode_metadata = None
        if self._num_decodes > 0:
            decode_metadata = AscendMLADecodeMetadata(
                input_positions=input_positions[:self._num_decode_tokens],
                block_table=block_table[:self._num_decodes, ...],
                seq_lens=seq_lens[:self._num_decodes])

        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
            slot_mapping=slot_mapping,
            head_dim=self.runner.model_config.get_head_size(),
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            num_prefills=self._num_prefills,
            attn_mask=self.runner.attn_mask,
            attn_state=self.runner.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )
class AscendMLAImpl(MLAAttentionImpl):
    """Ascend NPU implementation of Multi-head Latent Attention (MLA).

    Splits each step into a decode part (paged attention over the compressed
    KV cache) and a prefill part (chunked-prefill vanilla MLA), following the
    decode-first batch layout produced by AscendMLAMetadataBuilder.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        rotary_emb: RotaryEmbedding,
        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
        # attention backend perspective we rely on the layer to pass in the
        # correct matrix
        q_proj: ColumnParallelLinear,
        kv_b_proj: ColumnParallelLinear,
        o_proj: RowParallelLinear,
        **kwargs,
    ) -> None:
        # NOTE(review): alibi_slopes, sliding_window, blocksparse_params,
        # logits_soft_cap and attn_type are accepted for interface parity but
        # not stored or used here.
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim

        # Hack for V1 for now to avoid torch library overhead (since we are
        # already inside an attention custom op), pull out the forward
        # method from the rotary embedding and call it directly
        # TODO(lucas): we should probably find a cleaner way to do this
        self.rotary_emb = rotary_emb.forward_native

        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj

        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on RoCM and the
        # latter has an additional parameter to control FA2 vs FA3
        # self.flash_attn_varlen_func = flash_attn_varlen_func
        # if self.vllm_flash_attn_version is not None:
        #     self.flash_attn_varlen_func = \
        #         functools.partial(flash_attn_varlen_func,
        #                           fa_version=self.vllm_flash_attn_version)

    def _v_up_proj_and_o_proj(self, x):
        """Up-project latent attention output to V space, then apply o_proj."""
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        x = torch.bmm(x, self.W_UV)
        # Convert from (N, B, V) to (B, N * V)
        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
        return self.o_proj(x)[0]

    # Return `ql_nope`, `q_pe`
    def _q_proj_and_k_up_proj(self, x):
        """Project hidden states to queries and absorb the K up-projection.

        Splits the projected query into its nope/rope parts and multiplies the
        nope part by W_UK_T so decode attention can run directly in the
        compressed (latent) space.
        """
        q_nope, q_pe = self.q_proj(x)[0]\
            .view(-1, self.num_heads, self.qk_head_dim)\
            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        ql_nope = torch.bmm(q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        return ql_nope.transpose(0, 1), q_pe

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Precompute the W_UV / W_UK_T matrices from kv_b_proj weights.

        Called once after weight loading; dequantizes kv_b_proj if needed
        (offline-only, O(N^3)) and caches the split up-projection matrices
        used by the bmm-based decode path.
        """

        def get_layer_weight(layer):
            # Different quant methods expose the weight under different names.
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(layer.input_size_per_partition,
                                dtype=act_dtype,
                                device=get_layer_weight(layer).device)
                dequant_weights = layer.quant_method.apply(layer,
                                                           eye,
                                                           bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        # we currently do not have quantized bmm's which are needed for
        # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
        # the bmm's in 16-bit, the extra memory overhead of this is fairly low
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
                f"{kv_b_proj_weight.shape=}, "
                f"{self.kv_lora_rank=}, "
                f"{self.num_heads=}, "
                f"{self.qk_nope_head_dim=}, "
                f"{self.v_head_dim=}")
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)

    def _forward_prefill(
        self,
        query: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AscendMLAMetadata,
    ) -> torch.Tensor:
        """Chunked-prefill attention for the prefill portion of the batch.

        Returns o_proj output of shape (num_prefill_tokens, hidden_size).
        """
        assert attn_metadata.prefill is not None

        # TODO: enable this compute for flash attention computation
        # kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
        #     -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        # k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
        # v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]],
        #                                    value=0)
        num_tokens = query.size(0)
        attn_output = torch.empty(num_tokens,
                                  self.num_heads,
                                  self.v_head_dim,
                                  dtype=query.dtype,
                                  device=query.device)
        # current requests is chunked in prefill, disable flash attention with chunked prefill
        vanilla_chunked_prefill_mla(
            output=attn_output,
            query=query,
            kv_cache=kv_c_and_k_pe_cache,
            block_tables=attn_metadata.prefill.block_table,
            query_lens=attn_metadata.prefill.query_lens,
            context_lens=attn_metadata.prefill.context_lens,
            kv_b_proj=self.kv_b_proj,
            max_query_len=attn_metadata.prefill.max_query_len,
            max_context_len=attn_metadata.prefill.max_context_len,
            nope_dim=self.qk_nope_head_dim,
            rope_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            scale=self.scale,
            alibi_slopes=None,
            causal=True)
        attn_output = attn_output.view(
            [num_tokens, self.num_heads * self.v_head_dim])
        return self.o_proj(attn_output)[0]

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AscendMLAMetadata,
    ) -> torch.Tensor:
        """Paged MLA attention over the compressed KV cache for decode tokens."""
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)
        num_tokens = q.size(0)
        # NOTE(review): torch.empty would likely suffice here if
        # _npu_paged_attention_mla fully overwrites `out` — confirm before
        # changing; randn also costs an extra kernel per step.
        attn_output = torch.randn(
            [num_tokens, self.num_heads, self.kv_lora_rank],
            dtype=q.dtype,
            device=q.device)
        torch_npu._npu_paged_attention_mla(
            query=q,
            key_cache=kv_c_and_k_pe_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.decode.block_table,  # type:ignore
            context_lens=attn_metadata.decode.seq_lens,  # type:ignore
            mla_vheadsize=self.kv_lora_rank,
            out=attn_output)
        return self._v_up_proj_and_o_proj(attn_output)

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: M,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run MLA for a decode-first batch, writing results into `output`.

        Applies rotary embeddings, caches the compressed KV pairs, then
        dispatches the decode and prefill halves to their specialized paths.
        Returns the (possibly padded) output tensor passed in.
        """
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
            # Profiling run.
            return output
        num_actual_toks = attn_metadata.num_actual_tokens
        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
        output = output[:num_actual_toks, ...]
        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        # Restore head dim (for rotary embedding)
        k_pe = k_pe.unsqueeze(1)
        assert attn_metadata.num_decodes is not None and \
            attn_metadata.num_prefills is not None and \
            attn_metadata.num_decode_tokens is not None
        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
        num_decode_tokens = attn_metadata.num_decode_tokens
        # Decode tokens come first in the reordered batch (see
        # AscendMLAMetadataBuilder.reorder_batch), so a simple split suffices.
        decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
        decode_k_pe = k_pe[:num_decode_tokens]

        prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
        prefill_k_pe = k_pe[num_decode_tokens:]
        prefill_k_c_normed = k_c_normed[num_decode_tokens:]

        if has_decode:
            assert attn_metadata.decode is not None
            decode_ql_nope, decode_q_pe = \
                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
            # In-place rotary update on the q_pe / k_pe slices.
            decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
                decode_k_pe)

        if has_prefill:
            assert attn_metadata.prefill is not None
            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
                .view(-1, self.num_heads, self.qk_head_dim)
            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
            prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                attn_metadata.prefill.input_positions,
                prefill_q_pe.contiguous(), prefill_k_pe)

        if kv_cache.numel() > 0:
            # Write the compressed KV + rope parts into the paged cache.
            concat_and_cache_mla(k_c_normed, k_pe, kv_cache,
                                 attn_metadata.slot_mapping.flatten())
            # TODO: replaced back to ascend ops
            # key = torch.cat([k_c_normed.view([num_actual_toks, self.num_kv_heads, -1]), k_pe], dim=2)
            # torch_npu._npu_reshape_and_cache_siso(
            #     key=key,
            #     key_cache=kv_cache,
            #     slot_indices=attn_metadata.slot_mapping.flatten())

        if has_prefill:
            output[num_decode_tokens:] = self._forward_prefill(
                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                attn_metadata)

        if has_decode:
            output[:num_decode_tokens] = self._forward_decode(
                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
        return output_padded

View File

@@ -462,4 +462,4 @@ class LLMDataDistConnector(KVConnectorBase):
def close(self, ):
self.llm_datadist_engine.data_dist.unlink_clusters([self.cluster],
5000)
5000)

View File

@@ -0,0 +1,75 @@
from typing import Optional
import torch
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
init_model_parallel_group)
# vllm-ascend will maintain its own EP GroupCoordinator and ETP GroupCoordinator for
# customize parallel solution
# Expert-parallel group coordinator; populated by init_ascend_model_parallel.
_EP: Optional[GroupCoordinator] = None
# Expert-tensor-parallel group coordinator. NOTE: annotated as a single
# GroupCoordinator (not a list): init_ascend_model_parallel assigns the result
# of one init_model_parallel_group call, and get_etp_group returns it as
# `-> GroupCoordinator`.
_ETP: Optional[GroupCoordinator] = None
def get_ep_group() -> GroupCoordinator:
    """Return the expert-parallel GroupCoordinator.

    Fails with AssertionError if init_ascend_model_parallel has not run.
    """
    group = _EP
    assert group is not None, ("expert model parallel group is not initialized")
    return group
def get_etp_group() -> GroupCoordinator:
    """Return the expert-tensor-parallel GroupCoordinator.

    Fails with AssertionError if init_ascend_model_parallel has not run.
    """
    group = _ETP
    assert group is not None, (
        "expert tensor parallel group is not initialized")
    return group
def init_ascend_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    expert_tensor_parallel_size: int = 1,
    backend: Optional[str] = None,
):
    """Initialize vllm-ascend's EP and ETP process groups.

    Must be called after torch.distributed is initialized and at most once
    (asserts that _EP/_ETP are unset).

    NOTE(review): tensor_model_parallel_size and pipeline_model_parallel_size
    are accepted but never read in this function — confirm whether they are
    intended for future use or can be dropped.
    """
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    # EP groups are strided across ranks; there are exactly
    # expert_tensor_parallel_size of them (as written).
    num_expert_parallel_groups: int = expert_tensor_parallel_size
    num_expert_tensor_parallel_groups: int = (world_size //
                                              expert_tensor_parallel_size)

    global _EP
    assert _EP is None, ("expert parallel group is already initialized")
    group_ranks = []
    for i in range(num_expert_parallel_groups):
        # Stride pattern: group i holds ranks {i, i+G, i+2G, ...}.
        ranks = list(range(i, world_size, num_expert_parallel_groups))
        group_ranks.append(ranks)

    _EP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="ep")

    group_ranks = []
    global _ETP
    assert _ETP is None, (
        "expert tensor parallel group is already initialized")
    for i in range(num_expert_tensor_parallel_groups):
        # Contiguous pattern: group i holds ranks [i*S, (i+1)*S).
        ranks = list(
            range(i * expert_tensor_parallel_size,
                  (i + 1) * expert_tensor_parallel_size))
        group_ranks.append(ranks)

    _ETP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="etp")
def destory_ascend_model_parallel():
    """Tear down the EP and ETP groups and reset the module globals.

    Safe to call when the groups were never initialized (the `if` guards).
    NOTE(review): the name keeps the historical typo ("destory"); renaming
    would break external callers.
    """
    global _EP
    if _EP:
        _EP.destroy()
    _EP = None

    global _ETP
    if _ETP:
        _ETP.destroy()
    _ETP = None

View File

@@ -2,8 +2,15 @@ from vllm import ModelRegistry
def register_model():
    """Register vllm-ascend's custom model classes with vLLM's ModelRegistry.

    The imports are needed both to resolve the classes and for their module
    import side effects (hence the `noqa: F401` markers).
    NOTE(review): CustomDeepseekV2/V3ForCausalLM are imported here but no
    register_model call for them is visible in this hunk — presumably the
    surrounding (unshown) lines register them; confirm against the full file.
    """
    from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
    from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
    from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
    from .qwen2_vl import CustomQwen2VLForConditionalGeneration  # noqa: F401

    ModelRegistry.register_model(
        "DeepSeekMTPModel",
        "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")

    ModelRegistry.register_model(
        "Qwen2VLForConditionalGeneration",
        "vllm_ascend.models.qwen2_vl:CustomQwen2VLForConditionalGeneration")

View File

@@ -0,0 +1,173 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
SharedHead)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
    """One MTP (multi-token prediction) layer using the Ascend-custom
    DeepseekV2 decoder layer.

    Re-implements __init__ (bypassing the parent's) so the transformer block
    is a CustomDeepseekV2DecoderLayer instead of the upstream one.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        # Deliberately skip DeepSeekMultiTokenPredictorLayer.__init__ and
        # build the submodules ourselves.
        nn.Module.__init__(self)
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(config.hidden_size * 2,
                                 config.hidden_size,
                                 bias=False)
        self.shared_head = SharedHead(config=config, quant_config=quant_config)
        self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
                                                      model_config,
                                                      cache_config,
                                                      quant_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Fuse token embeddings with the previous step's hidden states and
        run one decoder block.

        NOTE: spec_step_index is accepted for interface parity but not used
        in this body.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None
        # masking inputs at position 0, as not needed by MTP
        inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
                                    torch.zeros_like(inputs_embeds),
                                    inputs_embeds)
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        hidden_states, residual = self.mtp_block(positions=positions,
                                                 hidden_states=hidden_states,
                                                 kv_cache=kv_cache,
                                                 attn_metadata=attn_metadata,
                                                 residual=None)
        hidden_states = residual + hidden_states
        return hidden_states
class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
    """Stack of MTP layers appended after the base model's hidden layers.

    Layer indices start at config.num_hidden_layers so they line up with the
    checkpoint's weight names.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        # to map the exact layer index from weights
        self.layers = torch.nn.ModuleDict({
            str(idx): CustomDeepSeekMultiTokenPredictorLayer(
                config,
                f"{prefix}.layers.{idx}",
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
            )
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        })

        # Note: torch._dynamo.exc.Unsupported: builtin: str
        # (keep a plain list so graph capture does not hit str() keys)
        self.layers_list = [
            self.layers[str(idx)]
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        ]
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Dispatch to the MTP layer for this speculative step.

        spec_step_idx is wrapped modulo num_mtp_layers so steps beyond the
        layer count reuse layers cyclically.
        """
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        return self.layers_list[current_step_idx](
            input_ids,
            positions,
            kv_caches[current_step_idx],
            attn_metadata,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits through the shared head of the current MTP layer."""
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        mtp_layer = self.layers_list[current_step_idx]
        logits = self.logits_processor(mtp_layer.shared_head.head,
                                       mtp_layer.shared_head(hidden_states),
                                       sampling_metadata)
        return logits
class CustomDeepSeekMTP(DeepSeekMTP):
    """DeepSeek MTP model wrapper that swaps in the Ascend-custom predictor.

    Skips DeepSeekMTP.__init__ so `model` is a
    CustomDeepSeekMultiTokenPredictor rather than the upstream predictor.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config.model_config.hf_config
        model_prefix = maybe_prefix(prefix, "model")
        self.model = CustomDeepSeekMultiTokenPredictor(
            vllm_config=vllm_config, prefix=model_prefix)
        self.sampler = get_sampler()

View File

@@ -19,36 +19,77 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from
# vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Optional, Union
# <<<<<<< HEAD
# # Adapted from
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
# """Inference-only DeepseekV2/DeepseekV3 model."""
# from typing import Optional, Union
# import torch
# from torch import nn
# from transformers import PretrainedConfig
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
# from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
# from vllm.model_executor.layers.fused_moe import FusedMoE
# from vllm.model_executor.layers.layernorm import RMSNorm
# from vllm.model_executor.layers.linear import ReplicatedLinear
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.quantization import QuantizationConfig
# from vllm.model_executor.layers.sampler import get_sampler
# from vllm.model_executor.layers.vocab_parallel_embedding import (
# ParallelLMHead, VocabParallelEmbedding)
# from vllm.model_executor.models.deepseek_v2 import ( # noqa
# DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
# DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE)
# =======
import os
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed as dist
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_dp_group, get_pp_group,
get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_v2 import ( # noqa
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE)
from vllm.model_executor.models.deepseek_v2 import \
DeepseekV2ForCausalLM # ruff: noqa: E501
from vllm.model_executor.models.deepseek_v2 import \
yarn_get_mscale # ruff: noqa: E501
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
DeepseekV2DecoderLayer,
DeepseekV2MLAAttention,
DeepseekV2MLP)
from vllm.model_executor.models.utils import (
PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# >>>>>>> dcd5c73 (Feat: Graph mode for deepseek v2/v3.)
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
class CustomDeepseekV2MoE(DeepseekV2MoE):
class CustomDeepseekV2MoE(nn.Module):
top_k: int
def __init__(
self,
@@ -56,10 +97,15 @@ class CustomDeepseekV2MoE(DeepseekV2MoE):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
nn.Module.__init__(self)
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.routed_scaling_factor = config.routed_scaling_factor
if self.tp_size > config.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.n_routed_experts}.")
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -76,7 +122,7 @@ class CustomDeepseekV2MoE(DeepseekV2MoE):
else:
self.gate.e_score_correction_bias = None
self.experts = FusedMoE(
self.experts = AscendFusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
@@ -99,9 +145,248 @@ class CustomDeepseekV2MoE(DeepseekV2MoE):
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
reduce_results=True,
prefix=f"{prefix}.shared_experts",
)
CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
vllm_config = get_current_vllm_config()
self.dp_size = get_dp_group().world_size
batch_size = vllm_config.scheduler_config.max_num_seqs
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1
params_dtype = torch.get_default_dtype()
self.final_hidden_states = torch.zeros(
[batch_size, config.hidden_size], dtype=params_dtype, device="npu")
self.tp_group = get_tp_group().device_group
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is None:
# for profile run
return hidden_states
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if (self.tp_size > 1 and self.enable_mc2
and attn_metadata.num_prefills == 0):
# hidden_states = dist._functional_collectives.reduce_scatter_tensor(
# hidden_states, "sum", scatter_dim=0, group=self.tp_group
# )
chunks = torch.chunk(hidden_states,
get_tp_group().world_size,
dim=0)
hidden_states = chunks[get_tp_group().rank_in_group]
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
is_prefill = True if attn_metadata.num_prefills > 0 else False
# is_prefill = attn_metadata.num_prefills > 0
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
if self.tp_size > 1:
if self.enable_mc2 and not is_prefill:
dist.all_gather_into_tensor(self.final_hidden_states,
final_hidden_states, self.tp_group)
final_hidden_states = self.final_hidden_states
else:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
return final_hidden_states.view(num_tokens, hidden_dim)
class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
    """Ascend variant of DeepseekV2 MLA attention.

    Rebuilds all projections itself (bypassing the parent's __init__) and
    binds `forward` to either the torchair (graph-mode) or eager path at
    construction time based on VLLM_ENABLE_GRAPH_MODE.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Low-rank query path (q_a -> norm -> q_b) when q_lora_rank is set,
        # otherwise a direct full-rank q projection.
        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        if rope_scaling:
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        #     kv_lora_rank + qk_rope_head_dim == head_size
        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=self.rotary_emb,
            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
        )

        self.prefix = prefix
        # Layer index parsed from a prefix like "model.layers.<idx>.self_attn"
        # — presumably used for debugging; TODO confirm prefix shape.
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

        # Select the forward path once at init time instead of branching per
        # call: graph (torchair) mode passes kv_cache/attn_metadata
        # explicitly, eager mode relies on the unified attention interface.
        if VLLM_ENABLE_GRAPH_MODE == "1":
            self.forward = self.forward_torchair
        else:
            self.forward = self.forward_eager  # type: ignore

    def forward_torchair(self,
                         positions: torch.Tensor,
                         hidden_states: torch.Tensor,
                         kv_cache: torch.Tensor = None,
                         attn_metadata=None):
        """Graph-mode forward: KV compression happens inside mla_attn.

        NOTE: `positions` is part of the decoder-layer calling convention but
        is not read here (rotary embedding is applied inside the backend).
        """
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
        return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
                             kv_cache, attn_metadata)

    def forward_eager(self, positions: torch.Tensor,
                      hidden_states: torch.Tensor):
        """Eager forward: compress KV here and hand the normalized latent
        plus rope part to the unified attention op."""
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
        return self.mla_attn(hidden_states_or_q_c,
                             kv_c_normed,
                             k_pe,
                             output_shape=hidden_states.shape)

    # Superseded by the forward_torchair/forward_eager split above; kept for
    # reference during the port.
    # def forward(
    #         self,
    #         positions: torch.Tensor,
    #         hidden_states: torch.Tensor,
    #         # torchair should pass below two parameters
    #         kv_cache: torch.Tensor = None,
    #         attn_metadata: AttentionMetadata = None,
    # ) -> torch.Tensor:
    #     if self.q_lora_rank is not None:
    #         ckq = self.q_a_proj(hidden_states)[0]
    #         hidden_states_or_q_c = self.q_a_layernorm(ckq)
    #     else:
    #         hidden_states_or_q_c = hidden_states
    #     if VLLM_ENABLE_GRAPH_MODE == '1':
    #         return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
    #                              kv_cache, attn_metadata)
    #     else:
    #         kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
    #             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    #         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
    #         return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, output_shape=hidden_states.shape)
    #                              kv_cache, attn_metadata)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -124,8 +409,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention
attn_cls = CustomDeepseekV2MLAAttention
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
@@ -180,8 +466,8 @@ class CustomDeepseekV2Model(nn.Module):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank:

View File

@@ -18,3 +18,4 @@ import vllm_ascend.ops.activation # noqa
import vllm_ascend.ops.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.rotary_embedding # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa

View File

@@ -0,0 +1,293 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
from vllm.model_executor.layers.linear import ColumnParallelLinear
# Implementation of vanilla chunked prefill, should be removed after the kernel is ready for
# all the corner case
def vanilla_chunked_prefill(
    output: torch.Tensor,
    query: torch.Tensor,  # (num_tokens, heads, head_size)
    key_cache: torch.Tensor,  # (num_blocks, block_size, kv_heads, head_size)
    value_cache: torch.
    Tensor,  # (num_blocks, block_size, kv_heads, head_size,)
    block_tables: torch.Tensor,  # (num_seqs, max_num_blocks_per_seq)
    cu_seqlen_q: torch.Tensor,  # (num_seqs + 1,)
    cu_seqlen_k: torch.Tensor,  # (num_seqs + 1,)
    max_seqlen_q: int,
    max_seqlen_k: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
    causal: bool = True,
) -> None:
    """Reference chunked-prefill attention written with plain torch ops.

    Gathers K/V for every sequence from the paged cache, zero-pads all
    sequences to a dense (max_seqlen_q, max_seqlen_k) batch, runs masked
    (optionally causal) softmax attention, and copies the unpadded result
    into ``output`` in place.  The result is also returned for convenience.

    NOTE(review): ``alibi_slopes`` is accepted but never used here.
    """
    num_query_heads = query.shape[1]
    head_dim = value_cache.shape[3]
    num_kv_heads = value_cache.shape[2]
    block_size = value_cache.shape[1]
    num_batch = cu_seqlen_q.shape[0] - 1
    max_num_blocks_per_seq = block_tables.shape[1]
    # Gather paged cache blocks into dense per-sequence K/V:
    # (num_batch, max_num_blocks_per_seq * block_size, kv_heads, head_dim),
    # then truncate to the longest context actually needed.
    key = key_cache[block_tables].view(num_batch,
                                       max_num_blocks_per_seq * block_size,
                                       num_kv_heads, head_dim)
    value = value_cache[block_tables].view(num_batch,
                                           max_num_blocks_per_seq * block_size,
                                           num_kv_heads, head_dim)
    key = key[:, :max_seqlen_k, :, :]
    value = value[:, :max_seqlen_k, :, :]
    # Per-sequence lengths recovered from the cumulative offsets.
    seqlen_k = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
    seqlen_q = cu_seqlen_q[1:] - cu_seqlen_q[:-1]
    seqlen_q = seqlen_q.view(-1, 1)
    seqlen_k = seqlen_k.view(-1, 1)
    seqlen_diff = seqlen_k - seqlen_q
    q_idx_mask = (torch.arange(0, max_seqlen_q,
                               device="npu").view(1, -1).repeat(num_batch, 1))
    k_idx_mask = (torch.arange(0, max_seqlen_k,
                               device="npu").view(1, -1).repeat(num_batch, 1))
    # True at valid (in-sequence) positions, False at padding.
    q_mask = q_idx_mask < seqlen_q
    k_mask = k_idx_mask < seqlen_k
    # calculate idx for causal mask of query [batch, max_seqlen_q]
    # (each query row is offset by how much longer the context is).
    causal_mask_idx = (q_idx_mask + seqlen_diff)[q_mask]
    # generate causal mask [batch, max_seqlen_q, max_seqlen_k]
    # as an additive mask: 0 where attendable, -inf where masked.
    tril_mask = torch.tril(torch.ones(max_seqlen_k, max_seqlen_k,
                                      device="npu"))
    tril_mask[tril_mask == 0] = float("-inf")
    tril_mask[tril_mask == 1] = 0
    causal_mask = tril_mask[causal_mask_idx]
    causal_mask_padding = torch.empty([num_batch, max_seqlen_q, max_seqlen_k],
                                      device="npu").fill_(float("-inf"))
    causal_mask_padding[q_mask] = causal_mask
    # to [batch, num_heads, max_seqlen_q, max_seqlen_k]
    causal_mask_padding = causal_mask_padding.unsqueeze(1)
    # Scatter the ragged Q/K/V into zero-padded dense batches.
    pad_q = torch.zeros(
        [num_batch, max_seqlen_q, num_query_heads, head_dim],
        device="npu",
        dtype=query.dtype,
    )
    pad_k = torch.zeros(
        [num_batch, max_seqlen_k, num_kv_heads, head_dim],
        device="npu",
        dtype=key.dtype,
    )
    pad_v = torch.zeros(
        [num_batch, max_seqlen_k, num_kv_heads, head_dim],
        device="npu",
        dtype=value.dtype,
    )
    pad_q[q_mask] = query
    pad_k[k_mask] = key[k_mask]
    pad_v[k_mask] = value[k_mask]
    # GQA/MQA: replicate each KV head to match the query head count.
    if num_query_heads > num_kv_heads:
        pad_k = pad_k.view(
            [num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
        pad_k = pad_k.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
            [num_batch, max_seqlen_k, num_query_heads, head_dim])
        pad_v = pad_v.view(
            [num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
        pad_v = pad_v.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
            [num_batch, max_seqlen_k, num_query_heads, head_dim])
    # permute to [b, h, n, k]
    pad_q = pad_q.permute(0, 2, 1, 3)
    pad_k = pad_k.permute(0, 2, 1, 3)
    pad_v = pad_v.permute(0, 2, 1, 3)
    # Additive padding mask over keys: 0 at valid positions, -inf at padding.
    attn_mask = torch.empty([num_batch, 1, 1, max_seqlen_k],
                            device="npu").fill_(float("-inf"))
    attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)
    # [b, h, f, t]
    attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
    attn_weights *= scale
    attn_mask = attn_mask.float()
    attn_weights = attn_weights + attn_mask
    if causal:
        attn_weights = attn_weights + causal_mask_padding
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
    attn_output = attn_output.permute(0, 2, 1, 3)
    # Drop the padding again and write back in the caller's dtype.
    attn_output = (attn_output[q_mask].view([-1, num_query_heads,
                                             head_dim]).to(output.dtype))
    output.copy_(attn_output)
    return attn_output
def vanilla_chunked_prefill_mla(
        output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
        query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, latent_kv)
        block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
        query_lens: torch.Tensor,  # (batch_size)
        context_lens: torch.Tensor,  # (batch_size)
        kv_b_proj: ColumnParallelLinear,  # ()
        max_query_len: int,
        max_context_len: int,
        nope_dim: int,
        rope_dim: int,
        v_head_dim: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        causal: bool = True) -> None:
    """Reference chunked-prefill attention for MLA (latent KV) layers.

    Gathers compressed latent KV entries from the paged cache, up-projects
    them with ``kv_b_proj`` into per-head keys/values, then runs padded,
    masked (optionally causal) softmax attention and copies the result into
    ``output`` in place.

    NOTE(review): ``alibi_slopes`` is accepted but never used here.  The
    cache is indexed via ``kv_cache.size(3)``, so despite the 3-D shape
    comment above it is expected to arrive 4-D and is ``squeeze()``-d
    below — confirm against the caller.
    """
    batch_size = block_tables.size(0)
    assert query_lens.size(0) == batch_size
    num_heads = query.size(1)
    block_size = kv_cache.size(1)
    # Last cache dim holds [latent_kv | rope]; split sizes derived here.
    latent_kv_dim = kv_cache.size(3) - rope_dim
    max_num_blocks_per_seq = block_tables.size(1)
    # NOTE(review): redundant re-assignment; batch_size was computed above.
    batch_size = query_lens.size(0)
    kv_cache = kv_cache.squeeze()
    # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim]
    cache_kv_c_pe = kv_cache[block_tables].view(
        batch_size, max_num_blocks_per_seq * block_size,
        latent_kv_dim + rope_dim)[:, :max_context_len, :]
    # get kv_c and k_pe
    # cached_kv_c: [batch_size, max_context_len, latent_kv]
    # cached_k_pe: [batch_size, max_context_len, rope_dim]
    cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim]
    cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:]
    # get k_rope and v
    # k_nope: [batch_size, max_context_len, num_heads, nope_dim]
    # value: [batch_size, max_context_len, num_heads, v_head_dim]
    k_nope, value = kv_b_proj(cache_kv_c)[0].view(
        batch_size, max_context_len, num_heads,
        nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)
    # key: [batch_size, max_context_len, num_heads, rope_dim + nope_dim]
    # (the single shared rope part is broadcast across all heads)
    key = torch.cat(
        [k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)],
        dim=-1)
    context_lens = context_lens.view(-1, 1).to("npu")
    query_lens = query_lens.view(-1, 1).to("npu")
    seq_diff = context_lens - query_lens
    q_idx_mask = (torch.arange(0, max_query_len,
                               device="npu").view(1, -1).repeat(batch_size, 1))
    kv_c_idx_mask = (torch.arange(0, max_context_len,
                                  device="npu").view(1,
                                                     -1).repeat(batch_size, 1))
    # True at valid (in-sequence) positions, False at padding.
    kv_c_mask = kv_c_idx_mask < context_lens
    q_mask = q_idx_mask < query_lens
    # calculate idx for causal mask of query [batch, max_seqlen_q]
    causal_mask_idx = (q_idx_mask + seq_diff)[q_mask]
    # generate causal mask [batch, max_seqlen_q, max_seqlen_k]
    # as an additive mask: 0 where attendable, -inf where masked.
    tril_mask = torch.tril(
        torch.ones(max_context_len, max_context_len, device="npu"))
    tril_mask[tril_mask == 0] = float("-inf")
    tril_mask[tril_mask == 1] = 0
    causal_mask = tril_mask[causal_mask_idx]
    causal_mask_padding = torch.empty(
        [batch_size, max_query_len, max_context_len],
        device="npu").fill_(float("-inf"))
    causal_mask_padding[q_mask] = causal_mask
    # to [batch, num_heads, max_seqlen_q, max_seqlen_k]
    causal_mask_padding = causal_mask_padding.unsqueeze(1)
    # Scatter the ragged Q/K/V into zero-padded dense batches.
    pad_q = torch.zeros(
        [batch_size, max_query_len, num_heads, rope_dim + nope_dim],
        device="npu",
        dtype=query.dtype,
    )
    pad_k = torch.zeros(
        [batch_size, max_context_len, num_heads, rope_dim + nope_dim],
        device="npu",
        dtype=key.dtype,
    )
    pad_v = torch.zeros(
        [batch_size, max_context_len, num_heads, v_head_dim],
        device="npu",
        dtype=value.dtype,
    )
    pad_q[q_mask] = query
    pad_k[kv_c_mask] = key[kv_c_mask]
    pad_v[kv_c_mask] = value[kv_c_mask]
    # permute to [b, h, n, d]
    pad_q = pad_q.permute(0, 2, 1, 3)
    pad_k = pad_k.permute(0, 2, 1, 3)
    pad_v = pad_v.permute(0, 2, 1, 3)
    # Additive padding mask over keys: 0 at valid positions, -inf at padding.
    attn_mask = torch.empty([batch_size, 1, 1, max_context_len],
                            device="npu").fill_(float("-inf"))
    attn_mask[:, :, :, :max_context_len].masked_fill_(
        kv_c_mask[:, None, None, :], 0)
    # [b, h, f, t]
    attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
    attn_weights *= scale
    attn_mask = attn_mask.float()
    attn_weights = attn_weights + attn_mask
    if causal:
        attn_weights = attn_weights + causal_mask_padding
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
    attn_output = attn_output.permute(0, 2, 1, 3)
    # Drop the padding again and write back in the caller's dtype.
    attn_output = (attn_output[q_mask].view([-1, num_heads,
                                             v_head_dim]).to(output.dtype))
    output.copy_(attn_output)
    return attn_output
def vanilla_decode_mla(
        query: torch.Tensor,  # [num_tokens, num_heads, latent_dim + rope_dim]
        key_cache: torch.
    Tensor,  # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim]
        num_kv_heads: int,
        num_heads: int,
        scale: float,
        block_table: torch.Tensor,  # [batch_size, max_block_size]
        context_lens: List[int],
        mla_vhead_size: int,
        rope_dim: int,
        output: torch.Tensor):
    """Reference (non-fused) single-token MLA decode attention.

    Gathers the latent KV cache for each sequence, runs one-query-token
    attention in the compressed latent space, and writes the result
    ([num_tokens, num_heads, latent_dim], still latent — the caller
    up-projects) into ``output`` in place.

    NOTE(review): ``mla_vhead_size`` is accepted but unused; the head
    expansion below assumes ``num_kv_heads == 1`` (DeepSeek MLA).
    """
    batch_size = block_table.size()[0]
    max_block_size = block_table.size()[1]
    reduce_dim = key_cache.size()[-1]
    block_size = key_cache.size()[1]
    latent_dim = reduce_dim - rope_dim
    # Gather cache blocks per sequence:
    # [batch_size, max_block_size * block_size, num_kv_heads, reduce_dim]
    kv_c_and_pe = key_cache[block_table].view(
        [batch_size, max_block_size * block_size, num_kv_heads, reduce_dim])
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1)
    # [batch_size, max_context_len, num_heads, latent_dim + rope_dim]
    # since the kv head is 1 in deepseek, we use expand here for perf.
    # BUGFIX: the last (non-singleton) dim must be expanded with -1; the
    # previous `expand(-1, -1, num_heads, 1)` raised at runtime.
    kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand(
        -1, -1, num_heads, -1)
    kv_c = kv_c_and_pe[..., :latent_dim]
    kv_idx_mask = (torch.arange(0, max_context_len,
                                device="npu").view(1,
                                                   -1).repeat(batch_size, 1))
    # [batch_size, max_context_len]; True at valid (in-context) positions.
    kv_idx_mask = kv_idx_mask < context_lens
    query = query.unsqueeze(1)
    attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe)
    attn_weights *= scale
    # BUGFIX: suppress out-of-context positions with -inf before softmax.
    # The previous code indexed the 2-D mask with four indices (IndexError)
    # and *added* the boolean mask, biasing valid positions by +1 instead
    # of masking invalid ones.
    attn_weights = attn_weights.masked_fill(~kv_idx_mask[:, None, None, :],
                                            float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights,
                               kv_c.float()).view(-1, num_heads, latent_dim)
    output.copy_(attn_output)
    return output

35
vllm_ascend/ops/cache.py Normal file
View File

@@ -0,0 +1,35 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def concat_and_cache_mla(
    kv_c_normed: torch.Tensor,  # [num_tokens, nope] (2-D: cat below requires it)
    k_pe: torch.Tensor,  # [num_tokens, num_kv_head, rope]
    kv_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_head, nope + rope]
    slot_mapping,  # [num_tokens]
):
    """Scatter normalized latent KV and rope keys into the paged KV cache.

    Each token's concatenated ``[kv_c_normed | k_pe]`` row is written into a
    flat ``(num_blocks * block_size)`` view of ``kv_cache`` at the slot given
    by ``slot_mapping``; the write is in place on the caller's tensor.
    """
    num_blocks, block_size = kv_cache.shape[0], kv_cache.shape[1]
    heads = k_pe.shape[1]
    # block index * block_size plus in-block offset == flat slot index.
    flat_slots = slot_mapping // block_size * block_size + slot_mapping % block_size
    flat_cache = kv_cache.view(num_blocks * block_size, heads, -1)
    entries = torch.cat([kv_c_normed.unsqueeze(1), k_pe], dim=-1)
    flat_cache[flat_slots] = entries

View File

@@ -15,12 +15,131 @@
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
import os
from typing import Callable, Optional
import torch
import torch.distributed as dist
import torch_npu
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
from vllm.config import get_current_vllm_config
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_dp_group
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.quantization.base_config import \
QuantizeMethodBase
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
def fused_experts_with_mc2(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    expert_map: torch.Tensor = None,
    moe_all_to_all_group_name: Optional[str] = None,
) -> torch.Tensor:
    """Fused MoE experts using the NPU MC2 dispatch/combine kernels.

    Tokens are scattered to their experts across the expert-parallel group
    with ``npu_moe_distribute_dispatch``, the per-expert computation
    (grouped GEMM -> SwiGLU -> grouped GEMM) is applied, and the results
    are gathered and weighted back with ``npu_moe_distribute_combine``.

    NOTE(review): ``top_k`` is accepted but unused here, and ``expert_map``
    is only used for its length (the global expert count) — confirm against
    callers before changing the signature.
    """
    global_bs = 0
    moe_expert_num = len(expert_map)
    # Base arguments for the dispatch kernel; stage-1 comm args added below.
    kwargs = {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": global_bs,
    }
    rank = torch.distributed.get_rank()
    quant_mode = 0
    ep_group = get_ep_group().device_group
    local_rank = torch.distributed.get_rank(group=ep_group)
    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
    # NOTE(review): "world_szie" is an existing typo for world_size (local
    # variable only, behavior unaffected).
    world_szie = torch.distributed.get_world_size()
    tp_size = world_szie // all_to_all_group_size
    tp_rank = rank % tp_size
    stage1_kwargs = {
        "scales": None,
        "quant_mode": quant_mode,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        # "group_tp": self.moe_rs_group_name,
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    kwargs.update(stage1_kwargs)
    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
    # comm_stream.wait_stream(torch.npu.current_stream())
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
        0:5]
    # Grouped matmul expects [K, N]-transposed expert weights.
    w1 = w1.transpose(1, 2)
    # Cumulative token counts delimit each expert's rows in expand_x.
    expert_token_nums = torch.cumsum(expert_token_nums,
                                     dim=0,
                                     dtype=torch.int64)
    group_list = expert_token_nums.to(torch.int64)
    gate_up_out_list = torch_npu.npu_grouped_matmul(
        x=[expand_x],
        weight=[w1],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )
    # TODO: Remove this in the future.
    gate_up_out = torch.cat(gate_up_out_list, dim=0)
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
    w2 = w2.transpose(1, 2)
    down_out_list = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )
    down_out_list = torch.cat(down_out_list, dim=0)
    # moeCombine
    kwargs = {
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expand_idx": expand_idx,
        "expert_scales": topk_weights.to(torch.float32),
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": 0,
    }
    tp_recv_counts = output[5]
    stage3_kwargs = {
        "ep_send_counts": ep_recv_counts,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        "tp_send_counts": tp_recv_counts,
        # "group_tp": self.moe_rs_group_name,
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    kwargs.update(stage3_kwargs)
    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
    return hidden_states
def fused_experts(
@@ -47,22 +166,27 @@ def fused_experts(
Returns:
hidden_states: Hidden states after routing.
"""
"""
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
"""
# if torch.distributed.get_rank() == 0:
# print(w1.shape)
# print(hidden_states.shape)
original_shape = hidden_states.shape
assert len(original_shape) == 2
# assert len(original_shape) == 2
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
dtype = hidden_states.dtype
device = hidden_states.device
assert dtype in [torch.float32, torch.float16, torch.bfloat16
], "Only float32, float16, and bfloat16 are supported"
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
# ], "Only float32, float16, and bfloat16 are supported"
if expert_map is not None:
# Generate token indices and flatten
@@ -152,11 +276,18 @@ def fused_experts(
final_hidden_states = torch.zeros(*original_shape,
device=hidden_states.device,
dtype=dtype)
final_hidden_states.index_add_(0, sorted_token_indices,
weighted_down_out)
# TODO: This should not happen! Look into it!
# fill nan with 0.0
final_hidden_states[torch.isnan(final_hidden_states)] = 0.0
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
# This created multiple NaN and index_add_ will mix them up which harms accracy
# remove this mask and filter after it being fixed
num_valid_tokens = mask.sum()
valid_token_mask = torch.arange(
0, sorted_token_indices.shape[0],
device=device).unsqueeze(1) < num_valid_tokens
valid_output = torch.where(
valid_token_mask, weighted_down_out,
torch.zeros_like(weighted_down_out)).to(dtype)
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
else:
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
@@ -199,16 +330,17 @@ def native_grouped_topk(
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
@@ -232,8 +364,23 @@ def select_experts(
Raises:
ValueError: If an unsupported scoring function is provided.
"""
assert hidden_states.shape[0] == router_logits.shape[0], (
"Number of tokens mismatch")
# assert hidden_states.shape[0] == router_logits.shape[0], (
# "Number of tokens mismatch")
# if os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1" and not is_prefill:
# topk_weight, topk_idx, _ = torch.ops.npu_inference.npu_moe_gating_top_k(
# router_logits,
# k=top_k, # topk当前写8
# bias=e_score_correction_bias,
# k_group=topk_group, # fix: 4
# group_count=num_expert_group, # fix 8
# group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
# renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
# norm_type=1, # 0: softmax; 1: sigmoid(fix)
# # out_flag=False, # todo new api; 第三个输出是否输出
# # y2_flag=False, # old api; 第三个输出是否输出
# routed_scaling_factor=1,
# eps=float(1e-20))
# return topk_weight, topk_idx
if custom_routing_function is not None:
raise NotImplementedError(
@@ -261,14 +408,16 @@ def select_experts(
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights, k=top_k, dim=-1,
topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(topk_weights,
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)
@@ -285,46 +434,245 @@ def select_experts(
return topk_weights, topk_ids
def forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
**kwargs,
):
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
def __init__(self):
super().__init__()
vllm_config = get_current_vllm_config()
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
ep_group = get_ep_group()
self.ep_size = ep_group.world_size
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.local_batch_size = self.global_batch_size // self.ep_size
try:
device_group = ep_group.device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
except AttributeError:
self.moe_all_to_all_group_name = None
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod,
self).process_weights_after_loading(layer)
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w13_weight.data),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill=False,
**kwargs,
):
# assert router_logits.shape[
# 1] == global_num_experts, "Number of global experts mismatch"
# set prefill as false always, should fix this
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
is_prefill=is_prefill)
if os.environ.get("VLLM_ENABLE_MC2") == "1" and not is_prefill:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
else:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
UnquantizedFusedMoEMethod.forward_oot = forward_oot
class AscendFusedMoE(FusedMoE):
    """FusedMoE layer for Ascend NPU.

    Differs from vLLM's ``FusedMoE`` in that expert parallelism (EP) and
    expert tensor parallelism (ETP) sizes come from vllm-ascend's own
    process groups, and the forward pass performs data-parallel
    all-gather / reduce-scatter around the expert computation.
    """

    def __init__(self,
                 num_experts,
                 top_k,
                 hidden_size,
                 intermediate_size,
                 params_dtype=None,
                 reduce_results=False,
                 renormalize=True,
                 use_grouped_topk=False,
                 num_expert_group=None,
                 topk_group=None,
                 quant_config=None,
                 tp_size=None,
                 ep_size=None,
                 dp_size=None,
                 prefix="",
                 custom_routing_function=None,
                 scoring_func="softmax",
                 e_score_correction_bias=None,
                 activation="silu"):
        # Deliberately skip FusedMoE.__init__: parallel sizes are derived
        # from the Ascend EP/ETP groups below instead of vLLM's TP group.
        super(FusedMoE, self).__init__()
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.ep_size = get_ep_group().world_size
        self.tp_size = get_etp_group().world_size
        self.dp_size = (dp_size
                        if dp_size is not None else get_dp_group().world_size)
        self.dp_rank = (0
                        if self.dp_size == 1 else get_dp_group().rank_in_group)
        self.top_k = top_k
        self.num_experts = num_experts
        self.global_num_experts = num_experts
        assert intermediate_size % self.tp_size == 0
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.expert_map = None
        self.activation = activation
        if self.ep_size > 1:
            # Build the local expert count and the num_experts-sized map
            # (-1 for experts owned by other EP ranks).
            self.local_num_experts, self.expert_map = determine_expert_map(
                self.ep_size,
                get_ep_group().rank_in_group, self.global_num_experts)
            self.tp_rank = get_etp_group().rank_in_group
            self.ep_rank = get_ep_group().rank_in_group
        else:
            # Adjust TP size for DP attention
            # haven't test its functionality yet, may remove in the future
            self.tp_rank = self.tp_size * self.dp_rank
            self.ep_rank = 0
            self.tp_size = self.tp_size * self.dp_size
            self.ep_size = 1
            self.local_num_experts = self.global_num_experts
            self.expert_map = None
        if self.scoring_func != "softmax" and not self.use_grouped_topk:
            raise ValueError("Only softmax scoring function is supported for "
                             "non-grouped topk.")
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                AscendUnquantizedFusedMoEMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None
        local_num_experts = torch.sum(self.expert_map != -1) \
            if self.expert_map is not None else num_experts
        moe_quant_params = {
            "num_experts": local_num_experts,
            "hidden_size": hidden_size,
            "intermediate_size_per_partition":
            self.intermediate_size_per_partition,
            "params_dtype": params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if (self.quant_method.__class__.__name__
                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
            moe_quant_params["intermediate_size_full"] = intermediate_size
        self.quant_method.create_weights(layer=self, **moe_quant_params)

    def forward(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                is_prefill: bool,
                top_k=None):
        """Run MoE routing + experts, handling DP gather/scatter around it.

        Returns the expert-combined hidden states; with ``reduce_results``
        set and TP/EP > 1, an all-reduce is applied before returning.
        """
        assert self.quant_method is not None
        real_top_k = top_k if top_k else self.top_k
        # BUGFIX: os.environ.get(...) returns None when the variable is
        # unset, so the previous bare int(...) raised TypeError; default
        # to "0" to disable the feature instead.
        enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", "0")) == 1
        if self.dp_size > 1:
            if enable_mc2 and not is_prefill:
                # MC2 decode path handles cross-rank exchange internally.
                ...
            elif int(os.environ.get("USING_LCCL_COM", "0")) == 1:
                hidden_states = get_dp_group().all_gather(
                    hidden_states, 0, False)
                router_logits = get_dp_group().all_gather(
                    router_logits, 0, False)
            else:
                hidden_states = get_dp_group().all_gather(hidden_states, 0)
                router_logits = get_dp_group().all_gather(router_logits, 0)
        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=real_top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            is_prefill=is_prefill)
        if self.dp_size > 1:
            if enable_mc2 and not is_prefill:
                ...
            else:
                # Undo the earlier DP all-gather.
                final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
                    final_hidden_states,
                    "sum",
                    scatter_dim=0,
                    group=get_dp_group().device_group)
        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states

View File

@@ -15,6 +15,7 @@
# This file is a part of the vllm-ascend project.
#
import math
from typing import Optional, Tuple
import torch
@@ -38,7 +39,7 @@ def rope_forward_oot(
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
# adopt custom kernel path for rotary_embedding
if CUSTOM_OP_ENABLED and self.is_neox_style:
if CUSTOM_OP_ENABLED and self.is_neox_style and self.head_size % 32 == 0:
return torch.ops._C.rotary_embedding(
positions,
query,
@@ -66,5 +67,169 @@ def rope_forward_oot(
return query.view(query_shape), key.view(key_shape)
def native_rope_deepseek_forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
):
# seq_len = positions.max() + 1
seq_len = self.max_position_embeddings
# x: [bs, num_attention_heads, seq_len, head_size]
# if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
# self._set_cos_sin_cache(seq_len=seq_len, device=query.device, dtype=query.dtype)
self._set_cos_sin_cache(seq_len=seq_len,
device=query.device,
dtype=query.dtype)
cos = self.cos_cached[:seq_len].to(dtype=query.dtype)
sin = self.sin_cached[:seq_len].to(dtype=query.dtype)
q_pe, k_pe = apply_rotary_pos_emb(query, key, cos, sin, positions)
return q_pe, k_pe
def rotate_half(x):
    """Rotate half the hidden dims: map ``[a, b]`` halves to ``[-b, a]``."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(num_rotations,
                             dim,
                             base=10000,
                             max_position_embeddings=2048):
    """Inverse dim formula: the dimension at which ``num_rotations`` full
    rotations occur over ``max_position_embeddings`` positions."""
    # Note: use torch instead of math to solve MTP compilation error.
    ratio = torch.tensor(max_position_embeddings) / (num_rotations * 2 *
                                                     torch.pi)
    return dim * torch.log(ratio) / (2 * torch.log(torch.tensor(base)))
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention magnitude-scaling factor (1.0 when the
    context is not extended, i.e. ``scale <= 1``)."""
    if scale > 1:
        return 0.1 * mscale * math.log(scale) + 1.0
    return 1.0
# Find dim range bounds based on rotations
def yarn_find_correction_range(low_rot,
                               high_rot,
                               dim,
                               base=10000,
                               max_position_embeddings=2048):
    """Return the (low, high) dimension bounds for YaRN interpolation,
    clamped to ``[0, dim - 1]``."""
    # Note: use torch floor/ceil/clamp instead of math and max/min to solve
    # MTP compilation error.
    lower = yarn_find_correction_dim(low_rot, dim, base,
                                     max_position_embeddings)
    upper = yarn_find_correction_dim(high_rot, dim, base,
                                     max_position_embeddings)
    return (torch.clamp(torch.floor(lower), min=0),
            torch.clamp(torch.ceil(upper), max=dim - 1))
def yarn_linear_ramp_mask(min_value, max_value, dim):
    """Build a length-``dim`` ramp: 0 below ``min_value``, 1 above
    ``max_value`` and linear in between.

    Note: no ``if`` branch is used (a tiny epsilon handles the
    ``min == max`` case instead) to solve the MTP compilation error.
    """
    # Bug fix: the original ``max_value += ...`` was an in-place tensor add
    # and mutated the caller's tensor; use an out-of-place add so the
    # argument is left untouched.
    max_value = max_value + (min_value == max_value).float() * 0.001
    linear_func = (torch.arange(dim, dtype=torch.float32) -
                   min_value) / (max_value - min_value)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.
    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Gather the per-position cos/sin rows, then add two broadcast axes so
    # they line up against the (head, seq) dims of q/k below.
    cos = cos[position_ids]
    sin = sin[position_ids]
    cos = cos[:, None, None, :]
    sin = sin[:, None, None, :]
    # Normalize q/k to rank 4 ([b, h, s, d]) regardless of the caller's
    # layout; missing head/seq axes are inserted as singleton dims.
    if len(q.shape) == 3:
        q = q[:, :, None, :]
    if len(k.shape) == 2:
        k = k[:, None, None, :]
    elif len(k.shape) == 3:
        k = k[:, :, None, :]
    # De-interleave the last dim: view as (d//2, 2) pairs and transpose so
    # [r0, i0, r1, i1, ...] becomes [r0, r1, ..., i0, i1, ...], the layout
    # expected by the rotate-half formulation.
    b, h_q, s, d = q.shape
    q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
    b, h_k, s, d = k.shape
    k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    # Collapse the seq axis back out; this view is only valid when s == 1
    # (per-token / decode layout) — TODO confirm callers never pass s > 1.
    q_embed = q_embed.view(b, h_q, d)
    k_embed = k_embed.view(b, h_k, d)
    return q_embed, k_embed
def _set_cos_sin_cache(self, seq_len, device, dtype):
    """Recompute and register the YaRN-interpolated cos/sin caches.

    Patched in as ``DeepseekScalingRotaryEmbedding._set_cos_sin_cache``.
    Registers ``inv_freq``, ``cos_cached`` and ``sin_cached`` as
    non-persistent buffers on ``self``.

    NOTE(review): the ``seq_len`` argument is intentionally ignored — the
    cache is always built for the full ``self.max_position_embeddings``.
    """
    seq_len = self.max_position_embeddings
    self.max_seq_len_cached = seq_len
    dim = self.rotary_dim
    # Extrapolated (original base) and interpolated (scaled) frequencies.
    freq_extra = 1.0 / (self.base**(
        torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
    freq_inter = 1.0 / (self.scaling_factor * self.base**(
        torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
    low, high = yarn_find_correction_range(
        self.beta_fast,
        self.beta_slow,
        dim,
        self.base,
        self.max_position_embeddings,
    )
    # Blend the two frequency sets with a linear ramp over [low, high]:
    # mask == 1 keeps the extrapolated frequency, mask == 0 the interpolated.
    inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
        device=device, dtype=torch.float32)
    inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    # _mscale = float(
    #     yarn_get_mscale(self.scaling_factor, self.mscale)
    #     / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
    # )
    # NOTE(review): the mscale ratio above is disabled; ``self.mscale`` is
    # applied directly to both cos and sin — confirm this matches the
    # upstream DeepseekScalingRotaryEmbedding attribute semantics.
    emb = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype),
                         persistent=False)
    self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype),
                         persistent=False)
# TODO: Patch when aclnn ops available
RotaryEmbedding.forward_oot = rope_forward_oot

# Fix: the previous `DeepseekScalingRotaryEmbedding.forward = rope_forward_oot`
# assignment was dead code — it was immediately overwritten by the native
# implementation below, which is the path actually used.
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache
# Reset the cached-length marker so the first forward rebuilds the caches.
DeepseekScalingRotaryEmbedding.max_seq_len_cached = None

View File

@@ -0,0 +1,67 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Tuple
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Map global token ids to this shard's local embedding rows.

    Returns ``(local_ids, invalid_mask)``: ids outside both the original and
    the added vocab ranges are mapped to row 0 and flagged in the mask.

    Only pointwise ops are used so torch.compile can fuse everything into a
    single kernel, which keeps this very fast.
    """
    in_org = (input_ >= org_vocab_start_index) & (input_ <
                                                  org_vocab_end_index)
    in_added = (input_ >= added_vocab_start_index) & (input_ <
                                                      added_vocab_end_index)
    # Added-vocab rows live right after the (padded) original shard.
    added_offset = added_vocab_start_index - (
        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
    shift = (org_vocab_start_index * in_org) + (added_offset * in_added)
    valid = in_org | in_added
    local_ids = valid * (input_ - shift)
    return local_ids, ~valid
def vocab_parallel_embedding_forward(self, input_):
    """Patched ``VocabParallelEmbedding.forward`` (TP-aware lookup).

    Out-of-shard token ids are remapped to row 0, embedded, zeroed out via
    the mask, and the partial embeddings are summed across TP ranks.
    """
    if self.tp_size > 1:
        # Build the local-id mapping and the out-of-shard mask.
        masked_input, input_mask = get_masked_input_and_mask(
            input_, self.shard_indices.org_vocab_start_index,
            self.shard_indices.org_vocab_end_index,
            self.shard_indices.num_org_vocab_padding,
            self.shard_indices.added_vocab_start_index,
            self.shard_indices.added_vocab_end_index)
    else:
        masked_input, input_mask = input_, None
    # Look up the (possibly remapped) ids on this shard.
    output_parallel = self.quant_method.embedding(self, masked_input.long())
    if input_mask is not None:
        # Zero rows for tokens owned by other shards before the reduction.
        output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
    # Sum the partial embeddings across all model-parallel ranks.
    return tensor_model_parallel_all_reduce(output_parallel)


VocabParallelEmbedding.forward = vocab_parallel_embedding_forward

View File

@@ -0,0 +1,16 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View File

@@ -15,4 +15,5 @@
# limitations under the License.
#
import vllm_ascend.patch.platform.patch_0_8_4.patch_config # noqa
import vllm_ascend.patch.platform.patch_0_8_4.patch_config # noqa
import vllm_ascend.patch.platform.patch_0_8_4.patch_distributed # noqa

View File

@@ -0,0 +1,138 @@
import torch
import vllm
import vllm.distributed
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout,
is_nccl_available)
from torch.distributed.rendezvous import rendezvous
from vllm.config import ParallelConfig
def ascend_destroy_model_parallel():
    """Destroy TP/PP/DP groups and reset vllm's module-level handles.

    Bug fix: the original used ``from vllm.distributed.parallel_state import
    _DP, _PP, _TP`` and then assigned ``None`` to those names — that only
    rebinds function-local aliases and leaves vllm's module globals pointing
    at destroyed groups. Assign through the module so the globals are
    actually cleared.
    """
    import vllm.distributed.parallel_state as parallel_state
    if parallel_state._TP:
        parallel_state._TP.destroy()
        parallel_state._TP = None
    if parallel_state._PP:
        parallel_state._PP.destroy()
        parallel_state._PP = None
    if parallel_state._DP:
        parallel_state._DP.destroy()
        parallel_state._DP = None
    # Let the platform tear down any Ascend-specific groups as well.
    from vllm.platforms import current_platform
    current_platform.destroy_platform_model_parallel()
def ascend_stateless_init_torch_distributed_process_group(
        host: str, port: int, rank: int, world_size: int,
        backend: str) -> ProcessGroup:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used
    because it depends on the global rank.

    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, we may have process
    1, 2, ..., 8 who want to communicate, and process 9 might be the same
    process as process 1, or it might be a different process; process 10
    might be the same process as process 5, or it might be a different process.
    In this case, how can we reliably form a communication channel within
    process 9 and 10, without affecting the communication channel within
    process 1, 2, ..., 8?

    One possible solution is to figure out if process 9 and 10 are the same
    as process 1 and 5 beforehand, and then form a communication channel
    based on the information, adjusting the ranks and world_size etc. However,
    figuring out the information is not always easy, and it will interfere
    with the main communication channel.

    Our solution is to always form a communication channel with process 1, 2,
    ..., 8, and then use this function to form another communication channel
    with process 9 and 10. This way, regardless of whether process 9 and 10
    are the same as process 1 and 5, the main communication channel is
    always formed with process 1, 2, ..., 8, and the additional communication
    channel is formed with process 9 and 10.
    """
    init_method = f"tcp://{host}:{port}"
    backend = Backend(backend)  # it is basically string
    timeout = _get_default_timeout(backend)
    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)
    group_rank = rank
    group_size = world_size
    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)
    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
    )
    from vllm.platforms import current_platform
    if backend == "gloo":
        from torch.distributed.distributed_c10d import ProcessGroupGloo
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        device = torch.device("cpu")
    elif backend == "nccl":
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL
        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout
        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
    elif current_platform.platform_has_backend_register():
        # Bug fix: the platform hook requires the group/store context
        # (see NPUPlatform.platform_register_backend, which takes pg,
        # prefix_store, group_rank, group_size, backend_options, timeout);
        # the original zero-argument call raised a TypeError. The hook
        # rebuilds its own backend options, so None is passed as a
        # placeholder for backend_options.
        current_platform.platform_register_backend(pg, prefix_store,
                                                   group_rank, group_size,
                                                   None, timeout)
        # The hook registers its backend on pg itself, so return directly.
        return pg
    else:
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
    pg._set_default_backend(backend_type)
    backend_class._set_sequence_number_for_group()
    pg._register_backend(device, backend_type, backend_class)
    return pg
def parallel_config_get_dp_port(self) -> int:
    """Return the port for the next data-parallel process group.

    Each call consumes one port from ``data_parallel_master_port`` so that
    process groups initialized in different processes (worker and engine)
    do not collide. Under torchrun, the ``MASTER_PORT`` env var wins.
    """
    import os

    reserved = self.data_parallel_master_port
    self.data_parallel_master_port += 1
    # NOTE: Get port from envs directly when using torchrun
    return int(os.environ.get("MASTER_PORT", reserved))  # type: ignore
# Monkey-patch vllm so every entry point picks up the Ascend-aware
# implementations defined above.
vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port

View File

@@ -13,4 +13,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
import vllm_ascend.patch.platform.patch_main.patch_distributed # noqa F401

View File

@@ -0,0 +1,138 @@
import torch
import vllm
import vllm.distributed
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout,
is_nccl_available)
from torch.distributed.rendezvous import rendezvous
from vllm.config import ParallelConfig
def ascend_destroy_model_parallel():
    """Destroy TP/PP/DP groups and reset vllm's module-level handles.

    Bug fix: the original used ``from vllm.distributed.parallel_state import
    _DP, _PP, _TP`` and then assigned ``None`` to those names — that only
    rebinds function-local aliases and leaves vllm's module globals pointing
    at destroyed groups. Assign through the module so the globals are
    actually cleared.
    """
    import vllm.distributed.parallel_state as parallel_state
    if parallel_state._TP:
        parallel_state._TP.destroy()
        parallel_state._TP = None
    if parallel_state._PP:
        parallel_state._PP.destroy()
        parallel_state._PP = None
    if parallel_state._DP:
        parallel_state._DP.destroy()
        parallel_state._DP = None
    # Let the platform tear down any Ascend-specific groups as well.
    from vllm.platforms import current_platform
    current_platform.destroy_platform_model_parallel()
def ascend_stateless_init_torch_distributed_process_group(
        host: str, port: int, rank: int, world_size: int,
        backend: str) -> ProcessGroup:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used
    because it depends on the global rank.

    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, we may have process
    1, 2, ..., 8 who want to communicate, and process 9 might be the same
    process as process 1, or it might be a different process; process 10
    might be the same process as process 5, or it might be a different process.
    In this case, how can we reliably form a communication channel within
    process 9 and 10, without affecting the communication channel within
    process 1, 2, ..., 8?

    One possible solution is to figure out if process 9 and 10 are the same
    as process 1 and 5 beforehand, and then form a communication channel
    based on the information, adjusting the ranks and world_size etc. However,
    figuring out the information is not always easy, and it will interfere
    with the main communication channel.

    Our solution is to always form a communication channel with process 1, 2,
    ..., 8, and then use this function to form another communication channel
    with process 9 and 10. This way, regardless of whether process 9 and 10
    are the same as process 1 and 5, the main communication channel is
    always formed with process 1, 2, ..., 8, and the additional communication
    channel is formed with process 9 and 10.
    """
    init_method = f"tcp://{host}:{port}"
    backend = Backend(backend)  # it is basically string
    timeout = _get_default_timeout(backend)
    # Rendezvous yields (store, rank, world_size) once all peers connect.
    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)
    group_rank = rank
    group_size = world_size
    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)
    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
    )
    from vllm.platforms import current_platform
    if backend == "gloo":
        from torch.distributed.distributed_c10d import ProcessGroupGloo
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        device = torch.device("cpu")
    elif backend == "nccl":
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL
        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout
        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
    elif current_platform.platform_has_backend_register():
        # NOTE(review): NPUPlatform.platform_register_backend takes
        # (pg, prefix_store, group_rank, group_size, backend_options,
        # timeout); this zero-argument call will raise a TypeError on the
        # HCCL path — it should pass the group/store context through.
        current_platform.platform_register_backend()
        return pg
    else:
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
    pg._set_default_backend(backend_type)
    backend_class._set_sequence_number_for_group()
    pg._register_backend(device, backend_type, backend_class)
    return pg
def parallel_config_get_dp_port(self) -> int:
    """Return the port for the next data-parallel process group.

    Each call consumes one port from ``data_parallel_master_port`` so that
    process groups initialized in different processes (worker and engine)
    do not collide. Under torchrun, the ``MASTER_PORT`` env var wins.
    """
    import os

    reserved = self.data_parallel_master_port
    self.data_parallel_master_port += 1
    # NOTE: Get port from envs directly when using torchrun
    return int(os.environ.get("MASTER_PORT", reserved))  # type: ignore
# Monkey-patch vllm so every entry point picks up the Ascend-aware
# implementations defined above.
vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port

View File

@@ -160,6 +160,8 @@ class NPUPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla):
if use_v1 and use_mla:
return "vllm_ascend.attention.mla_v1.AscendMLABackend"
if use_v1:
return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
if use_mla:
@@ -191,3 +193,30 @@ class NPUPlatform(Platform):
model configuration.
"""
return True
@classmethod
def destroy_platform_model_parallel(cls) -> None:
    """Release Ascend-specific model-parallel process groups."""
    import vllm_ascend.distributed.parallel_state as ascend_parallel_state

    # (sic) the helper is spelled "destory_..." in vllm_ascend.
    ascend_parallel_state.destory_ascend_model_parallel()
@classmethod
def platform_has_backend_register(cls) -> bool:
    # Ascend supplies its own process-group backend registration hook
    # (platform_register_backend), so always report True.
    return True
@classmethod
def platform_register_backend(cls, pg, prefix_store, group_rank,
                              group_size, backend_options,
                              timeout) -> None:
    """Build a ProcessGroupHCCL backend and register it on ``pg`` for NPU.

    NOTE(review): the ``backend_options`` parameter is ignored — it is
    immediately shadowed by a fresh ProcessGroupHCCL.Options below; confirm
    callers are not expected to pass custom options through.
    """
    from torch.distributed import ProcessGroup, is_hccl_available
    assert is_hccl_available()
    # torch_npu must be imported so the HCCL c10d bindings are available.
    import torch_npu  # noqa
    from torch_npu._C._distributed_c10d import ProcessGroupHCCL
    backend_options = ProcessGroupHCCL.Options()
    backend_options._timeout = timeout
    backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
                                     backend_options)
    device = torch.device("npu")
    backend_class._set_sequence_number_for_group()
    # HCCL is not a first-class c10d backend type, so register as CUSTOM.
    backend_type = ProcessGroup.BackendType.CUSTOM
    pg._register_backend(device, backend_type, backend_class)

View File

@@ -33,9 +33,7 @@ from vllm.model_executor.layers.quantization import \
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs
from .quantizer import AscendQuantizer
@@ -171,12 +169,10 @@ class AscendLinearMethod(LinearMethodBase):
output_size_per_partition,
params_dtype)
for weight_name, weight_param in weight_dict.items():
layer.register_parameter(
weight_name,
ModelWeightParameter(data=weight_param,
input_dim=1,
output_dim=0,
weight_loader=weight_loader))
param = torch.nn.Parameter(weight_param, requires_grad=False)
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
@@ -189,11 +185,10 @@ class AscendLinearMethod(LinearMethodBase):
perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype)
for perchannel_name, perchannel_param in perchannel_dict.items():
layer.register_parameter(
perchannel_name,
ChannelQuantScaleParameter(data=perchannel_param,
output_dim=0,
weight_loader=weight_loader))
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
@@ -264,48 +259,6 @@ class AscendKVCacheMethod(BaseKVCacheMethod):
seq_lens_tensor_cpu=seq_lens_tensor_cpu)
def fused_moe_perchannel_weight_loader(param: torch.nn.Parameter,
                                       loaded_weight: torch.Tensor,
                                       weight_name: str, shard_id: str,
                                       expert_id: int) -> None:
    """Copy one expert's checkpoint tensor into the fused MoE parameter.

    w1 (gate_proj) and w3 (up_proj) are TP-sharded along their
    intermediate dim and packed into the two halves of the fused w13 slot;
    w2 (down_proj) is copied whole.

    NOTE(review): ``get_tensor_model_parallel_rank`` is not imported in this
    visible chunk — confirm it is in scope at module level.
    """
    if shard_id not in ("w1", "w2", "w3"):
        raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                         f"got {shard_id}.")
    # Fetch the dim to shard the parameter/loaded weight
    # based on the shard id. This will be whatever
    # dimension intermediate_size_per_partition is used.
    SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
    expert_data = param.data[expert_id]
    tp_rank = get_tensor_model_parallel_rank()
    # is_transposed: if the dim to shard the weight
    # should be flipped. Required by GPTQ, compressed-tensors
    # should be whatever dimension intermediate_size_per_partition is
    is_transposed = getattr(param, "is_transposed", False)
    shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
    if is_transposed:
        shard_dim = int(not shard_dim)
    if shard_id == "w2":
        # NOTE(review): w2 is copied without a tp_rank narrow, unlike
        # w1/w3 — confirm the checkpoint stores w2 already per-rank.
        expert_data.copy_(loaded_weight)
    elif shard_id in ("w1", "w3"):
        # Take this rank's slice of the intermediate dimension.
        shard_size = expert_data.shape[shard_dim] // 2
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)
class AscendFusedMoEMethod(FusedMoEMethodBase):
"""FusedMoE method for Ascend quantization.
@@ -341,9 +294,6 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
# load `offset` weight in `fused_moe_perchannel_weight_loader`, the original weight load in vllm 0.7.3 could only load `scale` and `zero`
extra_weight_attrs.update(
{"weight_loader": fused_moe_perchannel_weight_loader})
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
@@ -360,15 +310,19 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
global_num_experts: int,
expert_map: torch.Tensor,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
is_prefill: bool = True,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
return self.quant_method.apply(layer, x, use_grouped_topk, top_k,
router_logits, renormalize, topk_group,
num_expert_group,
num_expert_group, global_num_experts,
expert_map, is_prefill,
custom_routing_function, scoring_func,
e_score_correction_bias)

View File

@@ -16,6 +16,8 @@
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/worker.py
#
import os
import torch
import torch_npu # noqa: F401
from packaging.version import Version
@@ -23,6 +25,8 @@ from vllm.logger import logger
import vllm_ascend.envs as envs
VLLM_ENABLE_GRAPH_MODE = os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0')
def try_register_lib(lib_name: str, lib_info: str = ""):
import importlib

View File

@@ -0,0 +1,17 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import vllm_ascend.worker.cache_engine # noqa

View File

@@ -0,0 +1,69 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/model_runner.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Tuple
import torch
from vllm.utils import is_pin_memory_available
from vllm.worker.cache_engine import CacheEngine
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
def allocate_kv_cache(
    self,
    num_blocks: int,
    device: str,
) -> List[Tuple]:
    """Allocate per-layer (nope, pe) KV-cache tensor pairs on *device*.

    Splits the last cache dim into a ``kv_lora_rank`` part and a
    ``qk_rope_head_dim`` part (MLA layout) instead of one fused tensor.
    """
    alloc_shape = self.attn_backend.get_kv_cache_shape(
        num_blocks, self.block_size, self.num_kv_heads, self.head_size)
    pin_memory = is_pin_memory_available() if device == "cpu" else False
    hf_config = self.model_config.hf_text_config

    def _zeroed(last_dim: int) -> torch.Tensor:
        # The null block in CpuGpuBlockAllocator requires at least that
        # block to be zeroed-out; we zero everything for simplicity.
        return torch.zeros(alloc_shape[:-1] + (last_dim, ),
                           dtype=self.dtype,
                           pin_memory=pin_memory,
                           device=device)

    return [(_zeroed(hf_config.kv_lora_rank),
             _zeroed(hf_config.qk_rope_head_dim))
            for _ in range(self.num_attention_layers)]
# Only patch the allocator in graph mode: the (nope, pe) tuple layout above
# is specific to the graph-mode MLA path.
if VLLM_ENABLE_GRAPH_MODE == '1':
    CacheEngine._allocate_kv_cache = allocate_kv_cache

View File

@@ -18,18 +18,21 @@
#
import dataclasses
import itertools
import weakref
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
Type, TypeVar, Union)
import numpy as np
import torch
import torch.nn as nn
import torch_npu
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
@@ -53,7 +56,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists,
is_pin_memory_available)
is_pin_memory_available, supports_dynamo)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@@ -72,6 +75,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
TModelInputForNPU = TypeVar('TModelInputForNPU', bound="ModelInputForNPU")
ENCODER_NUM = 0
@dataclass(frozen=True)
@@ -526,6 +530,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
seq_lens = []
max_decode_seq_len = 0
is_prompt = self.inter_data_list[0].is_prompt
for inter_data in self.inter_data_list:
seq_lens.extend(inter_data.seq_lens)
if not inter_data.is_prompt:
@@ -540,7 +545,26 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
for data in self.inter_data_list
}
input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens),
# Add graph_pad_size here
if self.runner.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
graph_pad_size = self.runner.scheduler_config.max_num_seqs - len(
seq_lens)
else:
graph_pad_size = -1
#print(f"before tensor input_tokens: {input_tokens}")
#print(f"before tensor input_positions: {input_positions}")
#print(f"before list seq_lens: {seq_lens}")
input_tokens = flatten_2d_lists(input_tokens)
input_positions = flatten_2d_lists(input_positions)
if graph_pad_size != -1 and not is_prompt:
input_tokens.extend(itertools.repeat(0, graph_pad_size))
input_positions.extend( # type: ignore
itertools.repeat(0, graph_pad_size))
seq_lens.extend(itertools.repeat(1, graph_pad_size))
query_lens.extend(itertools.repeat(1, graph_pad_size))
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.runner.device)
if mrope_input_positions is not None:
@@ -548,13 +572,16 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
dtype=torch.long,
device=self.runner.device)
else:
input_positions_tensor = torch.tensor(
flatten_2d_lists(input_positions),
dtype=torch.long,
device=self.runner.device)
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.runner.device)
#print(f"after tensor input_tokens_tensor: {input_tokens_tensor}")
#print(f"after tensor input_positions_tensor: {input_positions_tensor}")
#print(f"after list seq_lens: {seq_lens}")
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens)
attn_metadata = self.attn_metadata_builder.build(
seq_lens, query_lens, graph_pad_size)
# LoRA data.
lora_requests = set()
@@ -582,6 +609,13 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
]
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
if self.runner.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
torch._dynamo.mark_static(input_tokens_tensor)
torch._dynamo.mark_static(input_positions_tensor)
torch._dynamo.mark_static(attn_metadata.block_tables)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
return self.model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
@@ -841,6 +875,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
self.in_profile_run = False
self.graph_block_tables = np.zeros(
(self.vllm_config.scheduler_config.max_num_seqs,
(model_config.max_model_len + self.block_size - 1) //
self.block_size),
dtype=np.int32)
# Attention-free but stateful models like Mamba need a placeholder attn
# backend, as the attention metadata is needed to manage internal state.
# However we must bypass attention selection altogether for some models
@@ -930,6 +970,26 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
)
self.model = self.lora_manager.create_lora_manager(self.model)
# adapter torch compile with npu_backend
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
            # Capture communication ops into the graph
patch_for_hcom()
            # Set the NPU compiler config; if no config is provided the default is used (equivalent to npu_backend="npu")
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
torch.npu.set_compile_mode(jit_compile=False)
self.compile_model = torchair.inference.cache_compile(
self.model.forward,
dynamic=True,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config,
ge_cache=False)
def save_sharded_state(
self,
path: str,
@@ -1219,10 +1279,43 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
self.attn_state.begin_forward(model_input)
assert model_input.attn_metadata is not None
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
torch._dynamo.mark_static(model_input.input_tokens)
torch._dynamo.mark_static(model_input.input_positions)
torch._dynamo.mark_static(model_input.attn_metadata.block_tables)
torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping)
torch._dynamo.mark_static(
model_input.attn_metadata.query_start_loc)
torch._dynamo.mark_static(model_input.attn_metadata.seq_start_loc)
for kv in kv_caches:
if isinstance(kv, tuple):
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
# TODO(andoorve): We can remove this once all
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
model_executable = self.model
prefill_meta = model_input.attn_metadata.prefill_metadata
previous_hidden_states = kwargs.get("previous_hidden_states")
if prefill_meta is None and self.vllm_config.compilation_config.level > 0:
model_executable = self.compile_model
            # NOTE: graph_batch_size is computed differently from the GPU implementation.
graph_batch_size = model_input.input_tokens.shape[ # type: ignore
0] # type: ignore
            # NOTE: unlike the GPU path, previous_hidden_states may be None here.
if previous_hidden_states is not None:
previous_hidden_states = torch.cat([
previous_hidden_states,
torch.empty([
graph_batch_size - previous_hidden_states.shape[0],
*previous_hidden_states.shape[1:]
],
dtype=previous_hidden_states.dtype,
device=previous_hidden_states.device)
])
else:
model_executable = self.model
# Receive KV cache in distributed KV cache transfer setting
# In disagg prefill setting, it will also recv hidden states and bypass
@@ -1248,8 +1341,11 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
previous_hidden_states = kwargs.get("previous_hidden_states")
model_kwargs = {}
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
model_kwargs = {"inputs_embeds": None}
else:
model_kwargs = {}
if previous_hidden_states is not None:
model_kwargs["previous_hidden_states"] = previous_hidden_states
@@ -1273,44 +1369,30 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
**seqlen_agnostic_kwargs,
**model_kwargs)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
# Sending KV cache in distributed KV cache transfer setting
# NOTE: the send operation is non-blocking
if self.need_send_kv(model_input, kv_caches):
get_kv_transfer_group().send_kv_caches_and_hidden_states(
# model_executable is used to know which layer the current
# worker is working on, so that we can send KV for only those
# layers.
model_executable,
model_input,
kv_caches,
hidden_or_intermediate_states,
)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
and hidden_or_intermediate_states is not None
and isinstance(hidden_or_intermediate_states,
IntermediateTensors)
and self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
hidden_or_intermediate_states.tensors["model_forward_time"] = (
torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
and hidden_or_intermediate_states is not None
and isinstance(hidden_or_intermediate_states,
IntermediateTensors)
and self.observability_config is not None and
self.observability_config.collect_model_forward_time):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
hidden_or_intermediate_states.tensors[
"model_forward_time"] = (
torch.tensor(model_forward_time +
orig_model_forward_time))
return hidden_or_intermediate_states
# TODO: remove the synchronize here
torch.npu.synchronize()
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
if not self.is_driver_worker:
return []
@@ -1348,6 +1430,9 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
hidden_states = hidden_or_intermediate_states.index_select(
0, indices)
output.prefill_hidden_states = hidden_or_intermediate_states
elif self.vllm_config.compilation_config.level == \
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
hidden_states = hidden_or_intermediate_states[:len(indices)]
else:
hidden_states = hidden_or_intermediate_states

View File

@@ -19,6 +19,7 @@
import gc
import os
import weakref
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
@@ -47,8 +48,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.platform import NPUPlatform
if TYPE_CHECKING:
@@ -104,6 +104,27 @@ class NPUModelRunner:
raise NotImplementedError(
"Non-Attention backend is not supported by V1 NPUModelRunner.")
self.attn_backend = get_attn_backend(
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
if self.attn_backend is None:
error_msg = (
f"Error with get_att_backend: {self.head_size=}, "
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{self.model_config.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 GPUModelRunner.")
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
@@ -191,6 +212,12 @@ class NPUModelRunner:
pin_memory=True)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=True)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device="cpu",
@@ -200,6 +227,8 @@ class NPUModelRunner:
self.input_positions_cpu = torch.arange(0,
self.max_num_tokens,
device="cpu")
self.attn_mask = None
self.attn_state = None
# NOTE: Pre-construct a mask matrix to improve the efficiency of
# attention mask construction during inference.
@@ -396,7 +425,11 @@ class NPUModelRunner:
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# Copy the blocks from CPU to NPU.
modified_batch = self.attn_metadata_builder.reorder_batch(
self.input_batch, scheduler_output)
if modified_batch:
self.input_batch.refresh_sampling_metadata()
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit(num_reqs)
@@ -430,14 +463,13 @@ class NPUModelRunner:
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
positions = self.positions[:total_num_scheduled_tokens]
self.query_lens = torch.from_numpy(num_scheduled_tokens)
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
seq_lens = self.seq_lens_cpu[:num_reqs]
query_lens = torch.from_numpy(num_scheduled_tokens)
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
@@ -446,8 +478,6 @@ class NPUModelRunner:
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens])
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
self.device, non_blocking=True)
attn_state = AscendAttentionState.ChunkedPrefill
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
@@ -461,15 +491,14 @@ class NPUModelRunner:
query_lens=num_scheduled_tokens,
position=positions,
attn_state=attn_state)
self.attn_mask = attn_mask
self.attn_state = attn_state # type: ignore
attn_metadata = AscendMetadata(
seq_lens=query_lens,
context_lens=seq_lens,
slot_mapping=slot_mapping,
block_tables=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
attn_mask=attn_mask,
attn_state=attn_state,
attn_metadata = self.attn_metadata_builder.build( # type: ignore
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=None,
)
# Prepare input_ids
@@ -804,6 +833,9 @@ class NPUModelRunner:
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
        # TODO: remove this workaround once the OOM issue is located and fixed;
        # otherwise some models may encounter OOM.
num_blocks = num_blocks // 4
if isinstance(kv_cache_spec, FullAttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,

View File

@@ -44,6 +44,7 @@ from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import try_register_lib, vllm_version_is
from vllm_ascend.worker.model_runner import NPUModelRunner
@@ -313,8 +314,14 @@ class NPUWorker(LocalOrDistributedWorkerBase):
for ve in range(self.parallel_config.pipeline_parallel_size):
num_layers = len(self.cache_engine[ve].gpu_cache)
for i in range(num_layers):
torch_npu.npu_format_cast(self.cache_engine[ve].gpu_cache[i],
2)
if torch.is_tensor(self.cache_engine[ve].gpu_cache[i]):
torch_npu.npu_format_cast(
self.cache_engine[ve].gpu_cache[i], 2)
else:
torch_npu.npu_format_cast(
self.cache_engine[ve].gpu_cache[i][0], 2)
torch_npu.npu_format_cast(
self.cache_engine[ve].gpu_cache[i][1], 2)
self.gpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
@@ -495,6 +502,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
backend: str = "hccl") -> None:
"""Initialize the distributed environment."""
parallel_config = self.parallel_config
additional_config = self.vllm_config.additional_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank,
@@ -502,6 +510,14 @@ class NPUWorker(LocalOrDistributedWorkerBase):
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
expert_tensor_parallel_size = 1
if additional_config is not None and hasattr(
additional_config, "expert_tensor_parallel_size"):
expert_tensor_parallel_size = getattr(
additional_config, "expert_tensor_parallel_size")
init_ascend_model_parallel(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
expert_tensor_parallel_size)
ensure_kv_transfer_initialized(vllm_config)

View File

@@ -38,6 +38,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.worker_base import WorkerBase
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import try_register_lib, vllm_version_is
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -209,6 +210,8 @@ class NPUWorker(WorkerBase):
def _init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
additional_config = self.vllm_config.additional_config
parallel_config = self.vllm_config.parallel_config
set_custom_all_reduce(
not self.parallel_config.disable_custom_all_reduce)
init_distributed_environment(self.parallel_config.world_size,
@@ -217,6 +220,13 @@ class NPUWorker(WorkerBase):
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
expert_tensor_parallel_size = 1
if additional_config is not None and "expert_tensor_parallel_size" in additional_config:
expert_tensor_parallel_size = int(
additional_config["expert_tensor_parallel_size"])
init_ascend_model_parallel(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
expert_tensor_parallel_size)
ensure_kv_transfer_initialized(self.vllm_config)
def _init_profiler(self):