port deepseekv2 and mtp to main branch (#429)
### What this PR does / why we need it? This PR ports all the deepseek graph mode code and mtp code from v0.7.3 to the main branch --------- Signed-off-by: SidaoY <1024863041@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com> Signed-off-by: mengwei805 <mengwei25@huawei.com> Signed-off-by: libaokui <libaokui@huawei.com> Signed-off-by: q00832892 <qiaoyang19@huawei.com> Signed-off-by: ganyi <pleaplusone.gy@gmail.com> Co-authored-by: SidaoY <1024863041@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com> Co-authored-by: mengwei805 <mengwei25@huawei.com> Co-authored-by: libaokui <libaokui@huawei.com>
This commit is contained in:
@@ -18,3 +18,4 @@ import vllm_ascend.ops.activation # noqa
|
||||
import vllm_ascend.ops.fused_moe # noqa
|
||||
import vllm_ascend.ops.layernorm # noqa
|
||||
import vllm_ascend.ops.rotary_embedding # noqa
|
||||
import vllm_ascend.ops.vocab_parallel_embedding # noqa
|
||||
|
||||
293
vllm_ascend/ops/attention.py
Normal file
293
vllm_ascend/ops/attention.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
|
||||
|
||||
# Implementation of vanilla chunked prefill, should be removed after the kernel is ready for
|
||||
# all the corner case
|
||||
def vanilla_chunked_prefill(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor, # (num_tokens, heads, head_size)
|
||||
key_cache: torch.Tensor, # (num_blocks, block_size, kv_heads, head_size)
|
||||
value_cache: torch.
|
||||
Tensor, # (num_blocks, block_size, kv_heads, head_size,)
|
||||
block_tables: torch.Tensor, # (num_seqs, max_num_blocks_per_seq)
|
||||
cu_seqlen_q: torch.Tensor, # (num_seqs + 1,)
|
||||
cu_seqlen_k: torch.Tensor, # (num_seqs + 1,)
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
causal: bool = True,
|
||||
) -> None:
|
||||
num_query_heads = query.shape[1]
|
||||
head_dim = value_cache.shape[3]
|
||||
num_kv_heads = value_cache.shape[2]
|
||||
block_size = value_cache.shape[1]
|
||||
num_batch = cu_seqlen_q.shape[0] - 1
|
||||
max_num_blocks_per_seq = block_tables.shape[1]
|
||||
|
||||
key = key_cache[block_tables].view(num_batch,
|
||||
max_num_blocks_per_seq * block_size,
|
||||
num_kv_heads, head_dim)
|
||||
|
||||
value = value_cache[block_tables].view(num_batch,
|
||||
max_num_blocks_per_seq * block_size,
|
||||
num_kv_heads, head_dim)
|
||||
key = key[:, :max_seqlen_k, :, :]
|
||||
value = value[:, :max_seqlen_k, :, :]
|
||||
|
||||
seqlen_k = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
|
||||
seqlen_q = cu_seqlen_q[1:] - cu_seqlen_q[:-1]
|
||||
seqlen_q = seqlen_q.view(-1, 1)
|
||||
seqlen_k = seqlen_k.view(-1, 1)
|
||||
seqlen_diff = seqlen_k - seqlen_q
|
||||
q_idx_mask = (torch.arange(0, max_seqlen_q,
|
||||
device="npu").view(1, -1).repeat(num_batch, 1))
|
||||
k_idx_mask = (torch.arange(0, max_seqlen_k,
|
||||
device="npu").view(1, -1).repeat(num_batch, 1))
|
||||
q_mask = q_idx_mask < seqlen_q
|
||||
k_mask = k_idx_mask < seqlen_k
|
||||
|
||||
# calculate idx for causal mask of query [batch, max_seqlen_q]
|
||||
causal_mask_idx = (q_idx_mask + seqlen_diff)[q_mask]
|
||||
|
||||
# generate causal mask [batch, max_seqlen_q, max_seqlen_k]
|
||||
tril_mask = torch.tril(torch.ones(max_seqlen_k, max_seqlen_k,
|
||||
device="npu"))
|
||||
tril_mask[tril_mask == 0] = float("-inf")
|
||||
tril_mask[tril_mask == 1] = 0
|
||||
causal_mask = tril_mask[causal_mask_idx]
|
||||
causal_mask_padding = torch.empty([num_batch, max_seqlen_q, max_seqlen_k],
|
||||
device="npu").fill_(float("-inf"))
|
||||
causal_mask_padding[q_mask] = causal_mask
|
||||
# to [batch, num_heads, max_seqlen_q, max_seqlen_k]
|
||||
causal_mask_padding = causal_mask_padding.unsqueeze(1)
|
||||
|
||||
pad_q = torch.zeros(
|
||||
[num_batch, max_seqlen_q, num_query_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=query.dtype,
|
||||
)
|
||||
pad_k = torch.zeros(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=key.dtype,
|
||||
)
|
||||
pad_v = torch.zeros(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, head_dim],
|
||||
device="npu",
|
||||
dtype=value.dtype,
|
||||
)
|
||||
pad_q[q_mask] = query
|
||||
pad_k[k_mask] = key[k_mask]
|
||||
pad_v[k_mask] = value[k_mask]
|
||||
|
||||
if num_query_heads > num_kv_heads:
|
||||
pad_k = pad_k.view(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
|
||||
pad_k = pad_k.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
|
||||
[num_batch, max_seqlen_k, num_query_heads, head_dim])
|
||||
pad_v = pad_v.view(
|
||||
[num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
|
||||
pad_v = pad_v.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
|
||||
[num_batch, max_seqlen_k, num_query_heads, head_dim])
|
||||
# permute to [b, h, n, k]
|
||||
pad_q = pad_q.permute(0, 2, 1, 3)
|
||||
pad_k = pad_k.permute(0, 2, 1, 3)
|
||||
pad_v = pad_v.permute(0, 2, 1, 3)
|
||||
attn_mask = torch.empty([num_batch, 1, 1, max_seqlen_k],
|
||||
device="npu").fill_(float("-inf"))
|
||||
attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)
|
||||
# [b, h, f, t]
|
||||
attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
|
||||
attn_weights *= scale
|
||||
attn_mask = attn_mask.float()
|
||||
attn_weights = attn_weights + attn_mask
|
||||
if causal:
|
||||
attn_weights = attn_weights + causal_mask_padding
|
||||
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
attn_output = (attn_output[q_mask].view([-1, num_query_heads,
|
||||
head_dim]).to(output.dtype))
|
||||
output.copy_(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
def vanilla_chunked_prefill_mla(
|
||||
output: torch.Tensor, # (num_tokens, num_heads, v_head_dim)
|
||||
query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim)
|
||||
kv_cache: torch.Tensor, # (num_blocks, block_size, latent_kv)
|
||||
block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq)
|
||||
query_lens: torch.Tensor, # (batch_size)
|
||||
context_lens: torch.Tensor, # (batch_size)
|
||||
kv_b_proj: ColumnParallelLinear, # ()
|
||||
max_query_len: int,
|
||||
max_context_len: int,
|
||||
nope_dim: int,
|
||||
rope_dim: int,
|
||||
v_head_dim: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
causal: bool = True) -> None:
|
||||
batch_size = block_tables.size(0)
|
||||
assert query_lens.size(0) == batch_size
|
||||
num_heads = query.size(1)
|
||||
block_size = kv_cache.size(1)
|
||||
latent_kv_dim = kv_cache.size(3) - rope_dim
|
||||
max_num_blocks_per_seq = block_tables.size(1)
|
||||
batch_size = query_lens.size(0)
|
||||
kv_cache = kv_cache.squeeze()
|
||||
# select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim]
|
||||
cache_kv_c_pe = kv_cache[block_tables].view(
|
||||
batch_size, max_num_blocks_per_seq * block_size,
|
||||
latent_kv_dim + rope_dim)[:, :max_context_len, :]
|
||||
# get kv_c and k_pe
|
||||
# cached_kv_c: [batch_size, max_context_len, latent_kv]
|
||||
# cached_k_pe: [batch_size, max_context_len, rope_dim]
|
||||
cache_kv_c = cache_kv_c_pe[:, :, :latent_kv_dim]
|
||||
cache_k_pe = cache_kv_c_pe[:, :, latent_kv_dim:]
|
||||
# get k_rope and v
|
||||
# k_nope: [batch_size, max_context_len, num_heads, nope_dim]
|
||||
# value: [batch_size, max_context_len, num_heads, v_head_dim]
|
||||
k_nope, value = kv_b_proj(cache_kv_c)[0].view(
|
||||
batch_size, max_context_len, num_heads,
|
||||
nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)
|
||||
# key: [batch_size, max_context_len, num_hads, rope_dim + nope_dim]
|
||||
key = torch.cat(
|
||||
[k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)],
|
||||
dim=-1)
|
||||
|
||||
context_lens = context_lens.view(-1, 1).to("npu")
|
||||
query_lens = query_lens.view(-1, 1).to("npu")
|
||||
seq_diff = context_lens - query_lens
|
||||
|
||||
q_idx_mask = (torch.arange(0, max_query_len,
|
||||
device="npu").view(1, -1).repeat(batch_size, 1))
|
||||
kv_c_idx_mask = (torch.arange(0, max_context_len,
|
||||
device="npu").view(1,
|
||||
-1).repeat(batch_size, 1))
|
||||
kv_c_mask = kv_c_idx_mask < context_lens
|
||||
q_mask = q_idx_mask < query_lens
|
||||
|
||||
# calculate idx for causal mask of query [batch, max_seqlen_q]
|
||||
causal_mask_idx = (q_idx_mask + seq_diff)[q_mask]
|
||||
|
||||
# generate causal mask [batch, max_seqlen_q, max_seqlen_k]
|
||||
tril_mask = torch.tril(
|
||||
torch.ones(max_context_len, max_context_len, device="npu"))
|
||||
tril_mask[tril_mask == 0] = float("-inf")
|
||||
tril_mask[tril_mask == 1] = 0
|
||||
causal_mask = tril_mask[causal_mask_idx]
|
||||
causal_mask_padding = torch.empty(
|
||||
[batch_size, max_query_len, max_context_len],
|
||||
device="npu").fill_(float("-inf"))
|
||||
causal_mask_padding[q_mask] = causal_mask
|
||||
# to [batch, num_heads, max_seqlen_q, max_seqlen_k]
|
||||
causal_mask_padding = causal_mask_padding.unsqueeze(1)
|
||||
|
||||
pad_q = torch.zeros(
|
||||
[batch_size, max_query_len, num_heads, rope_dim + nope_dim],
|
||||
device="npu",
|
||||
dtype=query.dtype,
|
||||
)
|
||||
pad_k = torch.zeros(
|
||||
[batch_size, max_context_len, num_heads, rope_dim + nope_dim],
|
||||
device="npu",
|
||||
dtype=key.dtype,
|
||||
)
|
||||
pad_v = torch.zeros(
|
||||
[batch_size, max_context_len, num_heads, v_head_dim],
|
||||
device="npu",
|
||||
dtype=value.dtype,
|
||||
)
|
||||
pad_q[q_mask] = query
|
||||
pad_k[kv_c_mask] = key[kv_c_mask]
|
||||
pad_v[kv_c_mask] = value[kv_c_mask]
|
||||
|
||||
pad_q = pad_q.permute(0, 2, 1, 3)
|
||||
pad_k = pad_k.permute(0, 2, 1, 3)
|
||||
pad_v = pad_v.permute(0, 2, 1, 3)
|
||||
attn_mask = torch.empty([batch_size, 1, 1, max_context_len],
|
||||
device="npu").fill_(float("-inf"))
|
||||
attn_mask[:, :, :, :max_context_len].masked_fill_(
|
||||
kv_c_mask[:, None, None, :], 0)
|
||||
# [b, h, f, t]
|
||||
attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
|
||||
attn_weights *= scale
|
||||
attn_mask = attn_mask.float()
|
||||
attn_weights = attn_weights + attn_mask
|
||||
if causal:
|
||||
attn_weights = attn_weights + causal_mask_padding
|
||||
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
attn_output = (attn_output[q_mask].view([-1, num_heads,
|
||||
v_head_dim]).to(output.dtype))
|
||||
output.copy_(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
def vanilla_decode_mla(
|
||||
query: torch.Tensor, # [num_tokens, num_heads, latent_dim + rope_dim]
|
||||
key_cache: torch.
|
||||
Tensor, # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim]
|
||||
num_kv_heads: int,
|
||||
num_heads: int,
|
||||
scale: float,
|
||||
block_table: torch.Tensor, # [batch_size, max_block_size]
|
||||
context_lens: List[int],
|
||||
mla_vhead_size: int,
|
||||
rope_dim: int,
|
||||
output: torch.Tensor):
|
||||
batch_size = block_table.size()[0]
|
||||
max_block_size = block_table.size()[1]
|
||||
reduce_dim = key_cache.size()[-1]
|
||||
block_size = key_cache.size()[1]
|
||||
latent_dim = reduce_dim - rope_dim
|
||||
kv_c_and_pe = key_cache[block_table].view(
|
||||
[batch_size, max_block_size * block_size, num_kv_heads, reduce_dim])
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1)
|
||||
# [batch_size, max_context_len, num_kv_heads, latent_dim + rope_dim]
|
||||
# since the kv head is 1 in deepseek, we use expand here for perf
|
||||
kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand(
|
||||
-1, -1, num_heads, 1)
|
||||
kv_c = kv_c_and_pe[..., :latent_dim]
|
||||
kv_idx_mask = (torch.arange(0, max_context_len,
|
||||
device="npu").view(1,
|
||||
-1).repeat(batch_size, 1))
|
||||
# [batch_size, max_context_len]
|
||||
kv_idx_mask = kv_idx_mask < context_lens
|
||||
query = query.unsqueeze(1)
|
||||
attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe)
|
||||
attn_weights *= scale
|
||||
attn_weights = attn_weights + kv_idx_mask[:, -1, -1, :].float()
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights,
|
||||
kv_c.float()).view(-1, num_heads, latent_dim)
|
||||
output.copy_(attn_output)
|
||||
return output
|
||||
35
vllm_ascend/ops/cache.py
Normal file
35
vllm_ascend/ops/cache.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def concat_and_cache_mla(
|
||||
kv_c_normed: torch.Tensor, # [num_tokens, num_kv_head, nope]
|
||||
k_pe: torch.Tensor, # [num_tokens, num_kv_head, rope]
|
||||
kv_cache: torch.
|
||||
Tensor, # [num_blocks, block_size, num_kv_head, nope + rope]
|
||||
slot_mapping, # [num_tokens]
|
||||
):
|
||||
num_blocks = kv_cache.size()[0]
|
||||
block_size = kv_cache.size()[1]
|
||||
num_kv_head = k_pe.size()[1]
|
||||
|
||||
idx_for_copy = slot_mapping // block_size * block_size + slot_mapping % block_size
|
||||
kv_cache = kv_cache.view(num_blocks * block_size, num_kv_head, -1)
|
||||
kv_cache[idx_for_copy] = torch.cat([kv_c_normed.unsqueeze(1), k_pe],
|
||||
dim=-1)
|
||||
@@ -15,12 +15,131 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.fused_moe.layer import \
|
||||
UnquantizedFusedMoEMethod
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizeMethodBase
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
||||
|
||||
|
||||
def fused_experts_with_mc2(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
global_bs = 0
|
||||
moe_expert_num = len(expert_map)
|
||||
kwargs = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": topk_ids,
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": global_bs,
|
||||
}
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
quant_mode = 0
|
||||
ep_group = get_ep_group().device_group
|
||||
local_rank = torch.distributed.get_rank(group=ep_group)
|
||||
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
||||
|
||||
world_szie = torch.distributed.get_world_size()
|
||||
tp_size = world_szie // all_to_all_group_size
|
||||
tp_rank = rank % tp_size
|
||||
|
||||
stage1_kwargs = {
|
||||
"scales": None,
|
||||
"quant_mode": quant_mode,
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": all_to_all_group_size,
|
||||
"ep_rank_id": local_rank,
|
||||
# "group_tp": self.moe_rs_group_name,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": tp_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
}
|
||||
kwargs.update(stage1_kwargs)
|
||||
|
||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
||||
0:5]
|
||||
|
||||
w1 = w1.transpose(1, 2)
|
||||
expert_token_nums = torch.cumsum(expert_token_nums,
|
||||
dim=0,
|
||||
dtype=torch.int64)
|
||||
group_list = expert_token_nums.to(torch.int64)
|
||||
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[expand_x],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)
|
||||
|
||||
# TODO: Remove this in the future.
|
||||
gate_up_out = torch.cat(gate_up_out_list, dim=0)
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)
|
||||
|
||||
down_out_list = torch.cat(down_out_list, dim=0)
|
||||
|
||||
# moeCombine
|
||||
kwargs = {
|
||||
"expand_x": down_out_list,
|
||||
"expert_ids": topk_ids,
|
||||
"expand_idx": expand_idx,
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
}
|
||||
tp_recv_counts = output[5]
|
||||
stage3_kwargs = {
|
||||
"ep_send_counts": ep_recv_counts,
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": all_to_all_group_size,
|
||||
"ep_rank_id": local_rank,
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
# "group_tp": self.moe_rs_group_name,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": tp_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
}
|
||||
kwargs.update(stage3_kwargs)
|
||||
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts(
|
||||
@@ -47,22 +166,27 @@ def fused_experts(
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
"""
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
"""
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(w1.shape)
|
||||
# print(hidden_states.shape)
|
||||
|
||||
original_shape = hidden_states.shape
|
||||
assert len(original_shape) == 2
|
||||
# assert len(original_shape) == 2
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
num_experts = w1.shape[0]
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
||||
], "Only float32, float16, and bfloat16 are supported"
|
||||
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
||||
# ], "Only float32, float16, and bfloat16 are supported"
|
||||
|
||||
if expert_map is not None:
|
||||
# Generate token indices and flatten
|
||||
@@ -152,11 +276,18 @@ def fused_experts(
|
||||
final_hidden_states = torch.zeros(*original_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=dtype)
|
||||
final_hidden_states.index_add_(0, sorted_token_indices,
|
||||
weighted_down_out)
|
||||
# TODO: This should not happen! Look into it!
|
||||
# fill nan with 0.0
|
||||
final_hidden_states[torch.isnan(final_hidden_states)] = 0.0
|
||||
|
||||
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||
# This created multiple NaN and index_add_ will mix them up which harms accracy
|
||||
# remove this mask and filter after it being fixed
|
||||
num_valid_tokens = mask.sum()
|
||||
valid_token_mask = torch.arange(
|
||||
0, sorted_token_indices.shape[0],
|
||||
device=device).unsqueeze(1) < num_valid_tokens
|
||||
valid_output = torch.where(
|
||||
valid_token_mask, weighted_down_out,
|
||||
torch.zeros_like(weighted_down_out)).to(dtype)
|
||||
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
|
||||
else:
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
@@ -199,16 +330,17 @@ def native_grouped_topk(
|
||||
|
||||
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: Optional[bool] = True
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
@@ -232,8 +364,23 @@ def select_experts(
|
||||
Raises:
|
||||
ValueError: If an unsupported scoring function is provided.
|
||||
"""
|
||||
assert hidden_states.shape[0] == router_logits.shape[0], (
|
||||
"Number of tokens mismatch")
|
||||
# assert hidden_states.shape[0] == router_logits.shape[0], (
|
||||
# "Number of tokens mismatch")
|
||||
# if os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1" and not is_prefill:
|
||||
# topk_weight, topk_idx, _ = torch.ops.npu_inference.npu_moe_gating_top_k(
|
||||
# router_logits,
|
||||
# k=top_k, # topk当前写8
|
||||
# bias=e_score_correction_bias,
|
||||
# k_group=topk_group, # fix: 4
|
||||
# group_count=num_expert_group, # fix 8
|
||||
# group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
||||
# renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
# norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# # out_flag=False, # todo new api; 第三个输出是否输出
|
||||
# # y2_flag=False, # old api; 第三个输出是否输出
|
||||
# routed_scaling_factor=1,
|
||||
# eps=float(1e-20))
|
||||
# return topk_weight, topk_idx
|
||||
|
||||
if custom_routing_function is not None:
|
||||
raise NotImplementedError(
|
||||
@@ -261,14 +408,16 @@ def select_experts(
|
||||
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
||||
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
|
||||
topk_group)
|
||||
|
||||
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(topk_weights, k=top_k, dim=-1,
|
||||
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_weights.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(topk_weights,
|
||||
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)
|
||||
@@ -285,46 +434,245 @@ def select_experts(
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
ep_group = get_ep_group()
|
||||
self.ep_size = ep_group.world_size
|
||||
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.local_batch_size = self.global_batch_size // self.ep_size
|
||||
|
||||
try:
|
||||
device_group = ep_group.device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
||||
local_rank)
|
||||
except AttributeError:
|
||||
self.moe_all_to_all_group_name = None
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
super(UnquantizedFusedMoEMethod,
|
||||
self).process_weights_after_loading(layer)
|
||||
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w13_weight.data),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w2_weight.data),
|
||||
requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill=False,
|
||||
**kwargs,
|
||||
):
|
||||
# assert router_logits.shape[
|
||||
# 1] == global_num_experts, "Number of global experts mismatch"
|
||||
# set prefill as false always, should fix this
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if os.environ.get("VLLM_ENABLE_MC2") == "1" and not is_prefill:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||
else:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
|
||||
|
||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
|
||||
def __init__(self,
|
||||
num_experts,
|
||||
top_k,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
params_dtype=None,
|
||||
reduce_results=False,
|
||||
renormalize=True,
|
||||
use_grouped_topk=False,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
quant_config=None,
|
||||
tp_size=None,
|
||||
ep_size=None,
|
||||
dp_size=None,
|
||||
prefix="",
|
||||
custom_routing_function=None,
|
||||
scoring_func="softmax",
|
||||
e_score_correction_bias=None,
|
||||
activation="silu"):
|
||||
super(FusedMoE, self).__init__()
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
self.ep_size = get_ep_group().world_size
|
||||
self.tp_size = get_etp_group().world_size
|
||||
self.dp_size = (dp_size
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
self.dp_rank = (0
|
||||
if self.dp_size == 1 else get_dp_group().rank_in_group)
|
||||
|
||||
self.top_k = top_k
|
||||
self.num_experts = num_experts
|
||||
self.global_num_experts = num_experts
|
||||
assert intermediate_size % self.tp_size == 0
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
self.reduce_results = reduce_results
|
||||
self.renormalize = renormalize
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
if self.use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.scoring_func = scoring_func
|
||||
self.e_score_correction_bias = e_score_correction_bias
|
||||
self.expert_map = None
|
||||
self.activation = activation
|
||||
|
||||
if self.ep_size > 1:
|
||||
# Create a tensor of size num_experts filled with -1
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
self.ep_size,
|
||||
get_ep_group().rank_in_group, self.global_num_experts)
|
||||
self.tp_rank = get_etp_group().rank_in_group
|
||||
self.ep_rank = get_ep_group().rank_in_group
|
||||
else:
|
||||
# Adjust TP size for DP attention
|
||||
# haven't test its functionality yet, may remove in the future
|
||||
self.tp_rank = self.tp_size * self.dp_rank
|
||||
self.ep_rank = 0
|
||||
self.tp_size = self.tp_size * self.dp_size
|
||||
self.ep_size = 1
|
||||
self.local_num_experts = self.global_num_experts
|
||||
self.expert_map = None
|
||||
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
"non-grouped topk.")
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
AscendUnquantizedFusedMoEMethod())
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
assert self.quant_method is not None
|
||||
|
||||
local_num_experts = torch.sum(self.expert_map != -1) \
|
||||
if self.expert_map is not None else num_experts
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size_per_partition":
|
||||
self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_prefill: bool,
|
||||
top_k=None):
|
||||
assert self.quant_method is not None
|
||||
|
||||
if top_k:
|
||||
real_top_k = top_k
|
||||
else:
|
||||
real_top_k = self.top_k
|
||||
|
||||
if self.dp_size > 1:
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
|
||||
) == 1 and not is_prefill:
|
||||
...
|
||||
elif int(os.environ.get("USING_LCCL_COM")) == 1: # type: ignore
|
||||
hidden_states = get_dp_group().all_gather(
|
||||
hidden_states, 0, False)
|
||||
router_logits = get_dp_group().all_gather(
|
||||
router_logits, 0, False)
|
||||
else:
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=real_top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
if self.dp_size > 1:
|
||||
if int(os.environ.get("VLLM_ENABLE_MC2") # type: ignore
|
||||
) == 1 and not is_prefill:
|
||||
...
|
||||
else:
|
||||
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
final_hidden_states,
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group=get_dp_group().device_group)
|
||||
|
||||
# if self.reduce_results and self.tp_size > 1:
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -38,7 +39,7 @@ def rope_forward_oot(
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if CUSTOM_OP_ENABLED and self.is_neox_style:
|
||||
if CUSTOM_OP_ENABLED and self.is_neox_style and self.head_size % 32 == 0:
|
||||
return torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
@@ -66,5 +67,169 @@ def rope_forward_oot(
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
def native_rope_deepseek_forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# seq_len = positions.max() + 1
|
||||
seq_len = self.max_position_embeddings
|
||||
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
# if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
|
||||
# self._set_cos_sin_cache(seq_len=seq_len, device=query.device, dtype=query.dtype)
|
||||
self._set_cos_sin_cache(seq_len=seq_len,
|
||||
device=query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
cos = self.cos_cached[:seq_len].to(dtype=query.dtype)
|
||||
sin = self.sin_cached[:seq_len].to(dtype=query.dtype)
|
||||
|
||||
q_pe, k_pe = apply_rotary_pos_emb(query, key, cos, sin, positions)
|
||||
|
||||
return q_pe, k_pe
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
def yarn_find_correction_dim(num_rotations,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
return (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def yarn_find_correction_range(low_rot,
|
||||
high_rot,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
low = torch.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = torch.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
# Note: use torch instead of max/min to solve MTP compilation error.
|
||||
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_value, max_value, dim):
|
||||
# Note: The if conditional branch is not used here
|
||||
# to solve MTP compilation error.
|
||||
max_value += (min_value == max_value).float() * 0.001
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) -
|
||||
min_value) / (max_value - min_value)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids]
|
||||
sin = sin[position_ids]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
|
||||
if len(q.shape) == 3:
|
||||
q = q[:, :, None, :]
|
||||
if len(k.shape) == 2:
|
||||
k = k[:, None, None, :]
|
||||
elif len(k.shape) == 3:
|
||||
k = k[:, :, None, :]
|
||||
|
||||
b, h_q, s, d = q.shape
|
||||
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
|
||||
|
||||
b, h_k, s, d = k.shape
|
||||
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
q_embed = q_embed.view(b, h_q, d)
|
||||
k_embed = k_embed.view(b, h_k, d)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
seq_len = self.max_position_embeddings
|
||||
self.max_seq_len_cached = seq_len
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
freq_inter = 1.0 / (self.scaling_factor * self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
|
||||
device=device, dtype=torch.float32)
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
# _mscale = float(
|
||||
# yarn_get_mscale(self.scaling_factor, self.mscale)
|
||||
# / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
|
||||
# )
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype),
|
||||
persistent=False)
|
||||
self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype),
|
||||
persistent=False)
|
||||
|
||||
|
||||
# TODO: Patch when aclnn ops avaiable
|
||||
RotaryEmbedding.forward_oot = rope_forward_oot
|
||||
DeepseekScalingRotaryEmbedding.forward = rope_forward_oot
|
||||
# DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
|
||||
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
|
||||
DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache
|
||||
DeepseekScalingRotaryEmbedding.max_seq_len_cached = None
|
||||
|
||||
67
vllm_ascend/ops/vocab_parallel_embedding.py
Normal file
67
vllm_ascend/ops/vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import \
|
||||
VocabParallelEmbedding
|
||||
|
||||
|
||||
def get_masked_input_and_mask(
|
||||
input_: torch.Tensor, org_vocab_start_index: int,
|
||||
org_vocab_end_index: int, num_org_vocab_padding: int,
|
||||
added_vocab_start_index: int,
|
||||
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# torch.compile will fuse all of the pointwise ops below
|
||||
# into a single kernel, making it very fast
|
||||
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
|
||||
org_vocab_end_index)
|
||||
added_vocab_mask = (input_ >= added_vocab_start_index) & (
|
||||
input_ < added_vocab_end_index)
|
||||
added_offset = added_vocab_start_index - (
|
||||
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
|
||||
valid_offset = (org_vocab_start_index *
|
||||
org_vocab_mask) + (added_offset * added_vocab_mask)
|
||||
vocab_mask = org_vocab_mask | added_vocab_mask
|
||||
input_ = vocab_mask * (input_ - valid_offset)
|
||||
return input_, ~vocab_mask
|
||||
|
||||
|
||||
def vocab_parallel_embedding_forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
masked_input, input_mask = get_masked_input_and_mask(
|
||||
input_, self.shard_indices.org_vocab_start_index,
|
||||
self.shard_indices.org_vocab_end_index,
|
||||
self.shard_indices.num_org_vocab_padding,
|
||||
self.shard_indices.added_vocab_start_index,
|
||||
self.shard_indices.added_vocab_end_index)
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
|
||||
|
||||
|
||||
VocabParallelEmbedding.forward = vocab_parallel_embedding_forward
|
||||
Reference in New Issue
Block a user