Drop torchair (#4814)
aclgraph is stable and fast now. Let's drop torchair graph mode now.
TODO: some logic to adapt torchair should be cleaned up as well. We'll
do it in the following PR.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -18,15 +18,6 @@ from uuid import uuid4
|
||||
|
||||
from vllm.logger import logger
|
||||
|
||||
TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]
|
||||
|
||||
|
||||
def _check_torchair_supported(model_type: str):
|
||||
for supported_model in TORCHAIR_MODEL_LIST:
|
||||
if supported_model in model_type.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_kv_extra_config(vllm_config):
|
||||
|
||||
@@ -66,11 +57,6 @@ class AscendConfig:
|
||||
|
||||
def __init__(self, vllm_config):
|
||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||
torchair_graph_config = additional_config.get("torchair_graph_config",
|
||||
{})
|
||||
|
||||
self.torchair_graph_config = TorchairGraphConfig(
|
||||
torchair_graph_config, vllm_config, additional_config)
|
||||
|
||||
xlite_graph_config = additional_config.get("xlite_graph_config", {})
|
||||
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config,
|
||||
@@ -107,8 +93,8 @@ class AscendConfig:
|
||||
self.chunked_prefill_for_mla = additional_config.get(
|
||||
"chunked_prefill_for_mla", False)
|
||||
self.enable_shared_expert_dp = additional_config.get(
|
||||
"enable_shared_expert_dp", False
|
||||
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
|
||||
"enable_shared_expert_dp",
|
||||
False) and vllm_config.parallel_config.enable_expert_parallel
|
||||
if self.enable_shared_expert_dp:
|
||||
from vllm_ascend.utils import enable_sp
|
||||
assert enable_sp(vllm_config=vllm_config,
|
||||
@@ -215,86 +201,6 @@ class AscendCompilationConfig:
|
||||
# Add more compilation related configs here as needed
|
||||
|
||||
|
||||
class TorchairGraphConfig:
|
||||
"""
|
||||
Configuration Object for torchair_graph_config from additional_config
|
||||
"""
|
||||
|
||||
def __init__(self, torchair_graph_config, vllm_config, additional_config):
|
||||
self.enabled = torchair_graph_config.get("enabled", False)
|
||||
self.mode = torchair_graph_config.get("mode", '')
|
||||
self.use_cached_graph = torchair_graph_config.get(
|
||||
"use_cached_graph", False)
|
||||
self.use_cached_kv_cache_bytes = torchair_graph_config.get(
|
||||
"use_cached_kv_cache_bytes", False)
|
||||
self.graph_batch_sizes = torchair_graph_config.get(
|
||||
"graph_batch_sizes", [])
|
||||
self.graph_batch_sizes_init = torchair_graph_config.get(
|
||||
"graph_batch_sizes_init", False)
|
||||
self.enable_multistream_mla = torchair_graph_config.get(
|
||||
"enable_multistream_mla", False)
|
||||
self.enable_view_optimize = torchair_graph_config.get(
|
||||
"enable_view_optimize", True)
|
||||
self.enable_frozen_parameter = torchair_graph_config.get(
|
||||
"enable_frozen_parameter", True)
|
||||
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
|
||||
self.enable_super_kernel = torchair_graph_config.get(
|
||||
"enable_super_kernel", False)
|
||||
|
||||
if not isinstance(self.graph_batch_sizes, list):
|
||||
raise TypeError("graph_batch_sizes must be list[int]")
|
||||
if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0:
|
||||
raise ValueError(
|
||||
"graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
|
||||
)
|
||||
if not self.enabled:
|
||||
if self.mode:
|
||||
raise RuntimeError(
|
||||
"mode is valid only when Torchair graph mode is enabled")
|
||||
if self.use_cached_graph:
|
||||
raise RuntimeError(
|
||||
"use_cached_graph is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.use_cached_kv_cache_bytes:
|
||||
raise RuntimeError(
|
||||
"use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.graph_batch_sizes:
|
||||
raise RuntimeError(
|
||||
"graph_batch_sizes is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.graph_batch_sizes_init:
|
||||
raise RuntimeError(
|
||||
"graph_batch_sizes_init is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_multistream_mla:
|
||||
raise RuntimeError(
|
||||
"enable_multistream_mla is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_kv_nz:
|
||||
raise RuntimeError(
|
||||
"enable_kv_nz is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_super_kernel:
|
||||
raise RuntimeError(
|
||||
"enable_super_kernel is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.enable_super_kernel:
|
||||
if vllm_config.parallel_config.tensor_parallel_size != 1:
|
||||
raise RuntimeError(
|
||||
"enable_super_kernel is valid only when tensor_parallel_size is 1"
|
||||
)
|
||||
if not additional_config.get("multistream_overlap_shared_expert",
|
||||
False):
|
||||
raise RuntimeError(
|
||||
"enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled"
|
||||
)
|
||||
if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
|
||||
raise RuntimeError(
|
||||
"use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"
|
||||
)
|
||||
|
||||
|
||||
class XliteGraphConfig:
|
||||
"""
|
||||
Configuration Object for xlite_graph_config from additional_config
|
||||
@@ -382,39 +288,7 @@ def get_ascend_config():
|
||||
def check_ascend_config(vllm_config, enforce_eager):
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
# for eager mode
|
||||
if enforce_eager:
|
||||
# torchair_graph cannot be enabled with eager mode.
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
raise RuntimeError(
|
||||
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
|
||||
)
|
||||
# for graph mode
|
||||
else:
|
||||
# torchair_graph case
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
# torchair_graph is supported for deepseek/pangu/qwen model only.
|
||||
if vllm_config.model_config:
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
if not _check_torchair_supported(model_type):
|
||||
raise NotImplementedError(
|
||||
"Torchair graph mode only works with following model types:"
|
||||
f"{TORCHAIR_MODEL_LIST}.")
|
||||
if ascend_config.enable_shared_expert_dp:
|
||||
logger.warning(
|
||||
"enable_shared_expert_dp is not supported for torchair graph mode currently, "
|
||||
"it has been disabled automatically.")
|
||||
# aclgraph case
|
||||
else:
|
||||
if ascend_config.ascend_compilation_config.enable_quantization_fusion:
|
||||
logger.info(
|
||||
"Quantization fusion enabled! op fusion on quantization are expected. "
|
||||
)
|
||||
|
||||
if vllm_config.model_config:
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
if "qwen" not in model_type:
|
||||
logger.warning(
|
||||
"ACL Graph is currently experimental. Please "
|
||||
"raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
|
||||
" if you encourage any Error")
|
||||
if ascend_config.ascend_compilation_config.enable_quantization_fusion:
|
||||
logger.info(
|
||||
"Quantization fusion enabled! op fusion on quantization are expected. "
|
||||
)
|
||||
|
||||
@@ -857,7 +857,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.ring_mla_mask_size = 512
|
||||
@@ -1248,7 +1247,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv_no_split = kv_no_split.view(
|
||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||
cache_mode = "PA"
|
||||
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
@@ -1276,7 +1275,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv_no_split = kv_no_split.view(
|
||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||
cache_mode = "PA"
|
||||
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
kv_no_split,
|
||||
self.kv_a_layernorm.weight,
|
||||
@@ -1318,18 +1317,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# shape of knope/k_pe for npu graph mode should be:
|
||||
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
||||
actual_seq_lengths = None
|
||||
if self.enable_kv_nz:
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads,
|
||||
self.kv_lora_rank // 16, block_size, 16)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads,
|
||||
self.qk_rope_head_dim // 16, block_size, 16)
|
||||
input_layout = "BSND"
|
||||
else:
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
||||
self.kv_lora_rank)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||
self.qk_rope_head_dim)
|
||||
input_layout = "BNSD"
|
||||
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
||||
self.kv_lora_rank)
|
||||
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||
self.qk_rope_head_dim)
|
||||
input_layout = "BNSD"
|
||||
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.SpecDecoding,
|
||||
@@ -1346,14 +1338,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
||||
else:
|
||||
if self.enable_kv_nz:
|
||||
q_nope = q_nope.view(num_tokens, 1, self.num_heads,
|
||||
-1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
|
||||
else:
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
|
||||
-1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1,
|
||||
-1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
sparse_mode = 0
|
||||
spec_attn_mask = None
|
||||
|
||||
|
||||
@@ -345,7 +345,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||
|
||||
assert self.indexer is not None, "Indexer is required for DSA."
|
||||
@@ -534,7 +533,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||
kv_no_split = kv_no_split.view(
|
||||
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||
cache_mode = "PA"
|
||||
|
||||
if self.enable_sfa_cp:
|
||||
assert slots_cp is not None
|
||||
|
||||
@@ -453,7 +453,6 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
def _cat_kv_cache(self, block_ids: list[list[int]]):
|
||||
# Get necessary parameters
|
||||
k_cache = list(self.kv_caches.values())[0][0]
|
||||
kv_shape = k_cache.shape
|
||||
dtype = k_cache.dtype
|
||||
device = k_cache.device
|
||||
head_dim = self.model_config.hf_config.head_dim
|
||||
@@ -494,13 +493,6 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
|
||||
# Process each layer in the KV cache
|
||||
for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items():
|
||||
if len(
|
||||
k_cache_layer.shape
|
||||
) == 3: # kv shape in torchair model is [num_block, block_size, num_kv_head*head_dim]
|
||||
k_cache_layer = k_cache_layer.view(kv_shape[0], kv_shape[1],
|
||||
num_kv_head, head_dim)
|
||||
v_cache_layer = v_cache_layer.view(kv_shape[0], kv_shape[1],
|
||||
num_kv_head, head_dim)
|
||||
# Load cache data into buffers
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
k_cache_layer,
|
||||
|
||||
@@ -99,8 +99,6 @@ class MoECommMethod(ABC):
|
||||
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
# For TorchAir graph
|
||||
is_torchair: bool = False,
|
||||
# For Cube/Vector parallel
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
@@ -283,8 +281,6 @@ class FusedAlltoAllCommImpl(MoECommMethod):
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
# For TorchAir graph
|
||||
is_torchair: bool = False,
|
||||
# For Cube/Vector parallel
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
|
||||
@@ -26,10 +26,7 @@ from vllm.platforms import Platform, PlatformEnum
|
||||
# todo: please remove it when solve cuda hard code in vllm
|
||||
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
|
||||
|
||||
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
||||
init_ascend_config)
|
||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||
delete_torchair_cache_file)
|
||||
from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
|
||||
from vllm_ascend.utils import refresh_block_size
|
||||
|
||||
# isort: off
|
||||
@@ -204,25 +201,6 @@ class NPUPlatform(Platform):
|
||||
compilation_config.mode)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
logger.info(
|
||||
"Torchair compilation enabled on NPU. Setting CUDAGraphMode to NONE"
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
# Note: We delete the torchair cache folder here to prevent runtime issues caused by dimension
|
||||
# mismatches or configuration inconsistencies when users reuse cached computation graphs. Though
|
||||
# this will increase graph compilation duration, it significantly enhances robustness and decreases
|
||||
# graph launching time during inference.
|
||||
if check_torchair_cache_exist(
|
||||
) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
|
||||
logger.warning(
|
||||
"Torchair cache folder is deleted here to prevent runtime issues caused by dimension "
|
||||
"mismatches or configuration inconsistencies when users reuse cached computation graphs. "
|
||||
"In order to decrease torchair graph compilation time, users can enable both use_cached_graph "
|
||||
"and use_cached_kv_cache_bytes in torchair_graph_config.")
|
||||
delete_torchair_cache_file()
|
||||
|
||||
# set cudaprah sizes before extending `compilation_config.splitting_ops`
|
||||
vllm_config._set_cudagraph_sizes()
|
||||
# There are cases where default cudagraph_capture_sizes are not friendly
|
||||
@@ -303,9 +281,7 @@ class NPUPlatform(Platform):
|
||||
if parallel_config and parallel_config.worker_cls == "auto":
|
||||
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
|
||||
parallel_config.all2all_backend = "flashinfer_all2allv"
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
|
||||
elif ascend_config.xlite_graph_config.enabled:
|
||||
if ascend_config.xlite_graph_config.enabled:
|
||||
logger.info(
|
||||
"Euler Xlite enabled. See: https://gitee.com/openeuler/GVirt/tree/master/xlite"
|
||||
)
|
||||
@@ -390,29 +366,14 @@ class NPUPlatform(Platform):
|
||||
use_sparse=False,
|
||||
attn_type: str | None = None,
|
||||
):
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
if use_mla and ascend_config.enable_shared_expert_dp:
|
||||
if use_mla and use_sparse:
|
||||
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
|
||||
|
||||
use_torchair = ascend_config.torchair_graph_config.enabled
|
||||
# choose attention backend based on use_mla and use_torchair
|
||||
# choose attention backend based on use_mla
|
||||
backend_map = {
|
||||
(True, False, True):
|
||||
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
|
||||
(True, False, False):
|
||||
"vllm_ascend.attention.mla_v1.AscendMLABackend",
|
||||
(False, False, True):
|
||||
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
|
||||
(False, False, False):
|
||||
(True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
|
||||
(False, False):
|
||||
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
|
||||
(True, True, False):
|
||||
"vllm_ascend.attention.sfa_v1.AscendSFABackend",
|
||||
(True, True, True):
|
||||
"vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
|
||||
(True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
|
||||
}
|
||||
return backend_map[(use_mla, use_sparse, use_torchair)]
|
||||
return backend_map[(use_mla, use_sparse)]
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
|
||||
@@ -111,10 +111,9 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
ascend_config = get_ascend_config()
|
||||
self.use_aclgraph = (
|
||||
vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
|
||||
and not vllm_config.model_config.enforce_eager
|
||||
and not ascend_config.torchair_graph_config.enabled)
|
||||
self.use_aclgraph = (vllm_config.compilation_config.mode
|
||||
== CompilationMode.VLLM_COMPILE
|
||||
and not vllm_config.model_config.enforce_eager)
|
||||
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||
self.in_dtype = vllm_config.model_config.dtype
|
||||
|
||||
@@ -20,21 +20,14 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
|
||||
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
|
||||
|
||||
|
||||
def get_spec_decode_method(method,
|
||||
vllm_config,
|
||||
device,
|
||||
runner,
|
||||
is_torchair_graph=False):
|
||||
def get_spec_decode_method(method, vllm_config, device, runner):
|
||||
if method == "ngram":
|
||||
return NgramProposer(vllm_config, device, runner)
|
||||
elif method in ("eagle", "eagle3"):
|
||||
return EagleProposer(vllm_config, device, runner)
|
||||
elif method == "mtp":
|
||||
if is_torchair_graph:
|
||||
return TorchairMtpProposer(vllm_config, device, runner)
|
||||
return MtpProposer(vllm_config, device, runner)
|
||||
elif method == 'suffix':
|
||||
return SuffixDecodingProposer(vllm_config, device, runner)
|
||||
|
||||
@@ -1,357 +0,0 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import vllm
|
||||
from torch import nn
|
||||
from transformers import Qwen2Config
|
||||
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Attention # noqa: F401
|
||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM # noqa: F401
|
||||
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
||||
PPMissingLayer, maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import set_default_rope_theta
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
|
||||
|
||||
def all_gather_and_maybe_unpad(
|
||||
hidden_states: torch.Tensor,
|
||||
pad_size: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
||||
if pad_size > 0:
|
||||
return hidden_states[:-pad_size, :]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def maybe_pad_and_reduce_scatter(
|
||||
hidden_states: torch.Tensor,
|
||||
pad_size: int,
|
||||
) -> torch.Tensor:
|
||||
if pad_size > 0:
|
||||
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
|
||||
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen2Attention(Qwen2Attention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_parameters: Optional[dict[str, Any]] = None,
|
||||
max_position: int = 4096 * 32,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
max_position=max_position,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
attn_type=attn_type,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
rope_parameters=rope_parameters)
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.torchair_graph_enabled and attn_metadata is not None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
q, k = self.rotary_emb(positions,
|
||||
q,
|
||||
k,
|
||||
is_prefill=False,
|
||||
is_qwen_torchair=True)
|
||||
forward_kwargs = {}
|
||||
output_shape = q.shape
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
forward_kwargs['output'] = output
|
||||
|
||||
attn_output = self.attn.impl.forward(self.attn,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
**forward_kwargs)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
else:
|
||||
if type(self.rotary_emb) is RotaryEmbedding:
|
||||
q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
|
||||
else:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class CustomQwen2DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
set_default_rope_theta(config, default_theta=1000000)
|
||||
|
||||
dual_chunk_attention_config = getattr(config,
|
||||
"dual_chunk_attention_config",
|
||||
None)
|
||||
|
||||
# By default, Qwen2 uses causal attention as it is a decoder-only model.
|
||||
# You can override the HF config with `is_causal=False` to enable
|
||||
# bidirectional attention, which is used in some embedding models
|
||||
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
|
||||
if getattr(config, "is_causal", True):
|
||||
attn_type = AttentionType.DECODER
|
||||
else:
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
self.self_attn = CustomQwen2Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_parameters=config.rope_parameters,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.mlp = Qwen2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
|
||||
# otherwise (seq_len, ).
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
})
|
||||
class CustomQwen2Model(Qwen2Model):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
decoder_layer_type=decoder_layer_type)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
kv_cache = kv_caches[i - self.start_layer] \
|
||||
if kv_caches is not None else None
|
||||
hidden_states, residual = layer(positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# add `CustomQwen2Model` to init self.model
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata=None, # type: ignore
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM
|
||||
@@ -1,527 +0,0 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Adapted from vllm/model_executor/models/qwen3_moe.py
|
||||
# This file is a part of the vllm-ascend project.
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, CompilationMode, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.interfaces import (MixtureOfExperts,
|
||||
SupportsLoRA, SupportsPP)
|
||||
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
|
||||
Qwen3MoeDecoderLayer,
|
||||
Qwen3MoeForCausalLM,
|
||||
Qwen3MoeMLP, Qwen3MoeModel,
|
||||
Qwen3MoeSparseMoeBlock)
|
||||
from vllm.model_executor.models.utils import (
|
||||
PPMissingLayer, extract_layer_index,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
|
||||
init_metadata_for_sp)
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
|
||||
|
||||
|
||||
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.num_experts}.")
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.experts = TorchairAscendFusedMoE(
|
||||
num_experts=config.num_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
)
|
||||
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
self.dp_size = get_dp_group().world_size
|
||||
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attn_metadata=None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
):
|
||||
if attn_metadata is None:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# when profile runs, force experts to load balanced tokens
|
||||
# to avoid high memory consumption on a single rank.
|
||||
enable_force_load_balance = get_forward_context().in_profile_run
|
||||
is_prefill = get_forward_context().with_prefill
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=self.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=None,
|
||||
_metadata_for_padding=_metadata_for_padding,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen3MoeAttention(Qwen3MoeAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_parameters: dict[str, Any],
|
||||
max_position_embeddings: int = 8192,
|
||||
head_dim: Optional[int] = None,
|
||||
rms_norm_eps: float = 1e-06,
|
||||
qkv_bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=rope_parameters,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
@staticmethod
|
||||
def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int,
|
||||
head_dim: int, q_norm, k_norm):
|
||||
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
|
||||
q_by_head = q_norm(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
|
||||
k_by_head = k_norm(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size,
|
||||
self.head_dim, self.q_norm, self.k_norm)
|
||||
|
||||
if (self.torchair_graph_enabled and attn_metadata is not None and
|
||||
attn_metadata.attn_state == AscendAttentionState.DecodeOnly):
|
||||
q, k = self.rotary_emb(positions,
|
||||
q,
|
||||
k,
|
||||
is_prefill=False,
|
||||
is_qwen_torchair=True)
|
||||
forward_kwargs = {}
|
||||
output_shape = q.shape
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
forward_kwargs['output'] = output
|
||||
|
||||
attn_output = self.attn.impl.forward(self.attn,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
**forward_kwargs)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
else:
|
||||
q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: Optional[VllmConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = config.hidden_size
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = CustomQwen3MoeAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_parameters=config.rope_parameters,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=getattr(config, 'attention_bias', False),
|
||||
head_dim=getattr(config, 'head_dim', None),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
# `mlp_only_layers` in the config.
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
||||
config.mlp_only_layers)
|
||||
self.use_aclgraph = (vllm_config is not None
|
||||
and vllm_config.compilation_config.mode
|
||||
== CompilationMode.VLLM_COMPILE
|
||||
and not vllm_config.model_config.enforce_eager)
|
||||
if (layer_idx not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and
|
||||
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
if not self.use_aclgraph:
|
||||
# FIXME: custom sparse moe block doesn't work with aclgraph.
|
||||
self.mlp = CustomSparseMoeBlock(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
self.enable_sequence_parallelism = (
|
||||
vllm_config.compilation_config.pass_config.enable_sp
|
||||
if vllm_config is not None else False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# To prevent precision issues during the decoder phase when only prefilling enables SP
|
||||
if not self.enable_sequence_parallelism:
|
||||
self.self_attn.o_proj.reduce_results = True
|
||||
else:
|
||||
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
|
||||
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
residual = _metadata_for_padding.padding_slice(residual)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
||||
hidden_states)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
|
||||
hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if not self.use_aclgraph:
|
||||
hidden_states = self.mlp(
|
||||
hidden_states, _metadata_for_padding=_metadata_for_padding)
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomQwen3MoeModel(Qwen3MoeModel):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: CustomQwen3MoeDecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
_metadata_for_padding: Optional[MetadataForPadding] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
kv_caches[i -
|
||||
self.start_layer] if kv_caches is not None else None,
|
||||
attn_metadata,
|
||||
_metadata_for_padding=_metadata_for_padding)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
|
||||
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
|
||||
hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
SupportsPP.__init__(self)
|
||||
SupportsLoRA.__init__(self)
|
||||
MixtureOfExperts.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sp
|
||||
# Set MoE hyperparameters
|
||||
self.expert_weights: list[torch.Tensor] = []
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
example_layer = None
|
||||
for layer in self.model.layers:
|
||||
if isinstance(layer, PPMissingLayer):
|
||||
continue
|
||||
|
||||
assert isinstance(layer, Qwen3MoeDecoderLayer)
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
example_layer = layer.mlp
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if example_layer is None:
|
||||
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
|
||||
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
self.num_expert_groups = 1
|
||||
self.num_shared_experts = 0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
_metadata_for_padding = init_metadata_for_sp(
|
||||
input_ids, self.enable_sequence_parallelism)
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds, _metadata_for_padding)
|
||||
return hidden_states
|
||||
@@ -1,221 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/deepseek_mtp.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
||||
SharedHead)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
TorchairDeepseekV2DecoderLayer
|
||||
|
||||
|
||||
class TorchairDeepSeekShareHead(SharedHead):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "head"))
|
||||
|
||||
|
||||
class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = TorchairDeepSeekShareHead(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix,
|
||||
"shared_head"))
|
||||
self.mtp_block = TorchairDeepseekV2DecoderLayer(
|
||||
config, prefix, model_config, cache_config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
|
||||
torch.zeros_like(inputs_embeds),
|
||||
inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
del inputs_embeds, previous_hidden_states
|
||||
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
|
||||
|
||||
hidden_states, residual = self.mtp_block(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=None,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
replace_allreduce=replace_allreduce)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
TorchairDeepSeekMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
# Note: torch._dynamo.exc.Unsupported: builtin: str
|
||||
self.layers_list = [
|
||||
self.layers[str(idx)]
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
]
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
step_kv_cache = kv_caches[
|
||||
current_step_idx] if kv_caches is not None else None
|
||||
return self.layers_list[current_step_idx](
|
||||
input_ids,
|
||||
positions,
|
||||
step_kv_cache,
|
||||
attn_metadata,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
mtp_layer = self.layers_list[current_step_idx]
|
||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||
mtp_layer.shared_head(hidden_states))
|
||||
return logits
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class TorchairDeepSeekMTP(DeepSeekMTP):
|
||||
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
|
||||
# NOTE 2.The description file generated by the current msmodelslim tool does not have
|
||||
# MTP layer info. Please manually add it and set the value to FLOAT.
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = TorchairDeepSeekMultiTokenPredictor(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, hidden_states, inputs_embeds,
|
||||
spec_step_idx)
|
||||
return hidden_states
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,28 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
TorchairDeepseekV2ForCausalLM
|
||||
|
||||
|
||||
class TorchairDeepseekV3ForCausalLM(TorchairDeepseekV2ForCausalLM):
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,120 +0,0 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
|
||||
class MetadataForPadding:
|
||||
|
||||
def __init__(self,
|
||||
padding_flag=False,
|
||||
lengths_sum_padding=0,
|
||||
lengths_sum_unpadding=0,
|
||||
pad_size=0,
|
||||
not_dummy_and_is_prefill=False):
|
||||
self.padding_flag = padding_flag
|
||||
self.not_dummy_and_is_prefill = not_dummy_and_is_prefill
|
||||
|
||||
self.lengths_sum_padding = lengths_sum_padding
|
||||
self.lengths_sum_unpadding = lengths_sum_unpadding
|
||||
self.pad_size = pad_size
|
||||
|
||||
self.tp_size = get_tp_group().world_size
|
||||
self.tp_rank_in_group = get_tp_group().rank_in_group
|
||||
|
||||
assert self.lengths_sum_padding % self.tp_size == 0
|
||||
self.slice_size = self.lengths_sum_padding // self.tp_size
|
||||
|
||||
self.mc2_mask = torch.zeros(
|
||||
self.lengths_sum_padding,
|
||||
dtype=torch.bool,
|
||||
device=NPUPlatform.device_type,
|
||||
)
|
||||
self.mc2_mask[:lengths_sum_unpadding] = True
|
||||
|
||||
def padding_aligned_reduce_scatter(self,
|
||||
data: torch.Tensor) -> torch.Tensor:
|
||||
if self.padding_flag:
|
||||
pad_size = self.pad_size
|
||||
padded_data = F.pad(data, (0, 0, 0, pad_size))
|
||||
else:
|
||||
padded_data = data
|
||||
padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
|
||||
padded_data, 0)
|
||||
|
||||
return padded_data_reduce_scatter
|
||||
|
||||
def allgather_unpadding_aligned(self,
|
||||
padded_data: torch.Tensor) -> torch.Tensor:
|
||||
padded_data_allgather = tensor_model_parallel_all_gather(
|
||||
padded_data, 0)
|
||||
if self.padding_flag:
|
||||
lengths_sum_unpadding = self.lengths_sum_unpadding
|
||||
unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
|
||||
else:
|
||||
unpadding_data = padded_data_allgather
|
||||
return unpadding_data
|
||||
|
||||
def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
padded_data = F.pad(data, (0, 0, 0, self.pad_size))
|
||||
start = self.tp_rank_in_group * self.slice_size
|
||||
end = start + self.slice_size
|
||||
slice_data = padded_data[start:end]
|
||||
|
||||
return slice_data
|
||||
|
||||
def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if self.padding_flag:
|
||||
pad_size = self.pad_size
|
||||
padded_data = F.pad(data, (0, 0, 0, pad_size))
|
||||
else:
|
||||
padded_data = data
|
||||
# padded_data = data
|
||||
padded_data = torch.tensor_split(padded_data, self.tp_size, dim=0)
|
||||
|
||||
padded_data_reduce_scatter = padded_data[self.tp_rank_in_group]
|
||||
|
||||
return padded_data_reduce_scatter
|
||||
|
||||
|
||||
def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
|
||||
if not enable_sequence_parallelism:
|
||||
return MetadataForPadding(padding_flag=False,
|
||||
not_dummy_and_is_prefill=False)
|
||||
|
||||
is_perifll = 0
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if attn_metadata is not None:
|
||||
if hasattr(attn_metadata,
|
||||
'is_only_prefill') and attn_metadata.is_only_prefill:
|
||||
is_perifll = 1
|
||||
if hasattr(attn_metadata,
|
||||
'num_prefills') and attn_metadata.num_prefills > 0:
|
||||
is_perifll = 1
|
||||
|
||||
if is_perifll:
|
||||
lengths_sum_unpadding = input_ids.shape[0]
|
||||
lengths_sum_padding = (
|
||||
(lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
|
||||
if lengths_sum_unpadding == lengths_sum_padding:
|
||||
padding_flag = False
|
||||
else:
|
||||
padding_flag = True
|
||||
pad_size = lengths_sum_padding - lengths_sum_unpadding
|
||||
_metadata_for_padding = MetadataForPadding(
|
||||
lengths_sum_unpadding=lengths_sum_unpadding,
|
||||
lengths_sum_padding=lengths_sum_padding,
|
||||
padding_flag=padding_flag,
|
||||
pad_size=pad_size,
|
||||
not_dummy_and_is_prefill=True)
|
||||
|
||||
return _metadata_for_padding
|
||||
|
||||
return MetadataForPadding(padding_flag=False,
|
||||
not_dummy_and_is_prefill=False)
|
||||
@@ -1,245 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
|
||||
def dispose_tensor(x: torch.Tensor):
|
||||
x.set_(torch.empty([], device=x.device, dtype=x.dtype))
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerMetadata:
|
||||
"""Metadata for a layer.
|
||||
"""
|
||||
layer: Optional[LinearBase] # The layer object.
|
||||
post_method: Callable[[
|
||||
torch.nn.Module
|
||||
], None] # The `process_weights_after_loading` method from the quant method.
|
||||
weight: torch.Tensor # The weight tensor.
|
||||
window_idx: int # The index of the window.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedWindowMetadata:
|
||||
"""Metadata for a shared window.
|
||||
"""
|
||||
weight: torch.Tensor # The weight tensor to be shared by layers.
|
||||
data_layer_idx: int # The index of the layer this window's weight is equal to.
|
||||
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeriesMetadata:
|
||||
"""Metadata for a weight shared series.
|
||||
"""
|
||||
group: GroupCoordinator
|
||||
start_layer: int
|
||||
end_layer: int
|
||||
num_layers: int
|
||||
prefetch_step: int
|
||||
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
|
||||
layers: list[LayerMetadata]
|
||||
shared_windows: list[
|
||||
SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
|
||||
window_offset: int # The index of the window for the next coming layer.
|
||||
|
||||
def is_source(self, layer_idx) -> bool:
|
||||
return layer_idx % self.group.world_size == self.group.rank_in_group
|
||||
|
||||
def post_process_after_loading(self):
|
||||
# This method only needs to be called once per series.
|
||||
if self.shared_windows:
|
||||
return
|
||||
for layer_idx in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[layer_idx - self.start_layer]
|
||||
is_source = self.is_source(layer_idx)
|
||||
# If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight.
|
||||
if not is_source:
|
||||
layer.weight.set_(torch.empty_like(self.dummy_weight))
|
||||
# Broadcast to get the true weight.
|
||||
dist.broadcast(layer.weight,
|
||||
src=self.group.ranks[layer_idx %
|
||||
self.group.world_size],
|
||||
group=self.group.device_group)
|
||||
assert layer.layer is not None
|
||||
# Call `process_weights_after_loading` from the quant method.
|
||||
layer.post_method(layer.layer)
|
||||
step = layer_idx - self.start_layer
|
||||
if step < self.prefetch_step:
|
||||
# Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights.
|
||||
self.shared_windows.append(
|
||||
SharedWindowMetadata(
|
||||
weight=layer.weight.clone().detach(),
|
||||
data_layer_idx=layer_idx,
|
||||
work=None,
|
||||
))
|
||||
layer.window_idx = step
|
||||
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
||||
if not is_source:
|
||||
layer.weight.set_(self.shared_windows[-1].weight)
|
||||
else:
|
||||
# Build one more window for prefetch. The weight is useless, so just keep the shape.
|
||||
if step == self.prefetch_step:
|
||||
self.shared_windows.append(
|
||||
SharedWindowMetadata(
|
||||
weight=torch.empty_like(layer.weight),
|
||||
data_layer_idx=-1,
|
||||
work=None,
|
||||
))
|
||||
# When the layer not intended to be stored in this device, dispose the tensor.
|
||||
if not is_source:
|
||||
dispose_tensor(layer.weight)
|
||||
|
||||
dispose_tensor(self.dummy_weight)
|
||||
|
||||
def reach_layer(self, layer_idx: int):
|
||||
# The index of the layer to be prefetched.
|
||||
next_layer_idx = (layer_idx + self.prefetch_step
|
||||
) % self.num_layers + self.start_layer
|
||||
next_layer = self.layers[next_layer_idx - self.start_layer]
|
||||
# The index of the window to store the weight for the coming layer.
|
||||
next_layer.window_idx = self.window_offset
|
||||
window = self.shared_windows[next_layer.window_idx]
|
||||
# When the layer not intended to be stored in this device, link to the corresponding window's tensor.
|
||||
if not self.is_source(next_layer_idx):
|
||||
next_layer.weight.set_(window.weight)
|
||||
# Update `window_offset` by rolling one step.
|
||||
self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
|
||||
1)
|
||||
assert window.data_layer_idx != next_layer_idx
|
||||
window.data_layer_idx = next_layer_idx
|
||||
# Start asynchronous broadcast work.
|
||||
window.work = dist.broadcast(
|
||||
next_layer.weight,
|
||||
src=self.group.ranks[next_layer_idx % self.group.world_size],
|
||||
group=self.group.device_group,
|
||||
async_op=True)
|
||||
|
||||
def wait_weight(self, layer_idx: int):
|
||||
# Find the asynchronous broadcast work and wait for it.
|
||||
assert self.shared_windows
|
||||
window = self.shared_windows[self.layers[layer_idx -
|
||||
self.start_layer].window_idx]
|
||||
# Make sure the data in the corresponding shared window is for the current layer.
|
||||
assert window.data_layer_idx == layer_idx
|
||||
if window.work is not None:
|
||||
window.work.wait()
|
||||
window.work = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerExternalMetadata:
|
||||
"""External metadata for a layer.
|
||||
"""
|
||||
series: SeriesMetadata
|
||||
layer_idx: int
|
||||
|
||||
|
||||
_series_dict: dict[str, SeriesMetadata] = {}
|
||||
|
||||
_layer_external_dict: dict[int, LayerExternalMetadata] = {}
|
||||
|
||||
|
||||
def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
|
||||
layer_idx: int) -> Callable:
|
||||
|
||||
def wrapped_forward(*args, **kwargs):
|
||||
# Wait for the weight.
|
||||
series.wait_weight(layer_idx)
|
||||
return forward(*args, **kwargs)
|
||||
|
||||
return wrapped_forward
|
||||
|
||||
|
||||
"""
|
||||
Register linear layers into a shared storage series.
|
||||
|
||||
In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.
|
||||
|
||||
After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization.
|
||||
|
||||
During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.
|
||||
|
||||
Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
|
||||
- total_layers = end_layer - start_layer
|
||||
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer
|
||||
|
||||
To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series.
|
||||
|
||||
Arguments:
|
||||
series_name: This name identifies which series this layer belongs to.
|
||||
group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series.
|
||||
start_layer: The index of the first layer in the series (inclusive).
|
||||
end_layer: The index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer).
|
||||
layer_idx: The index of the current layer.
|
||||
layer: The linear layer object to register.
|
||||
prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases.
|
||||
"""
|
||||
|
||||
|
||||
def register_layer_to_shared_weight_series(
|
||||
series_name: str,
|
||||
group: GroupCoordinator,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
layer_idx: int,
|
||||
layer: LinearBase,
|
||||
prefetch_step: int = 1,
|
||||
):
|
||||
global _series_dict
|
||||
if series_name not in _series_dict:
|
||||
num_layers = end_layer - start_layer
|
||||
assert num_layers > 0
|
||||
assert prefetch_step >= 0 and prefetch_step <= num_layers - 2
|
||||
_series_dict[series_name] = SeriesMetadata(
|
||||
group=group,
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
num_layers=num_layers,
|
||||
prefetch_step=prefetch_step,
|
||||
dummy_weight=torch.empty_like(layer.weight),
|
||||
layers=[
|
||||
LayerMetadata(
|
||||
layer=None,
|
||||
post_method=lambda layer: None,
|
||||
weight=torch.empty([]),
|
||||
window_idx=-1,
|
||||
) for _ in range(num_layers)
|
||||
],
|
||||
shared_windows=[],
|
||||
window_offset=prefetch_step,
|
||||
)
|
||||
series = _series_dict[series_name]
|
||||
assert layer.quant_method is not None
|
||||
series.layers[layer_idx - start_layer] = LayerMetadata(
|
||||
layer=layer,
|
||||
post_method=layer.quant_method.process_weights_after_loading,
|
||||
weight=layer.weight,
|
||||
window_idx=-1,
|
||||
)
|
||||
# Discard the original `process_weights_after_loading` method such that it won't be called by others.
|
||||
layer.quant_method.process_weights_after_loading = lambda layer: None
|
||||
# When the layer not intended to be stored in this device, dispose the tensor and skip weight loading.
|
||||
if not series.is_source(layer_idx):
|
||||
dispose_tensor(layer.weight)
|
||||
layer.weight.weight_loader = lambda *args, **kwargs: None
|
||||
layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx)
|
||||
global _layer_external_dict
|
||||
_layer_external_dict[id(layer)] = LayerExternalMetadata(
|
||||
series=series,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
|
||||
def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
|
||||
ext = _layer_external_dict[id(layer)]
|
||||
ext.series.post_process_after_loading()
|
||||
|
||||
|
||||
def reach_layer_for_shared_weight_series(layer: LinearBase):
|
||||
ext = _layer_external_dict[id(layer)]
|
||||
ext.series.reach_layer(ext.layer_idx)
|
||||
@@ -1,37 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""AscendSiluAndMul forward in torchair mode.
|
||||
|
||||
The key difference from the original implementation is the removal of operators
|
||||
from the torch.ops.vllm class, as these operators only function in non-torchair
|
||||
modes. Adding them back would cause the graph compilation to fail.
|
||||
"""
|
||||
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
|
||||
else:
|
||||
out = torch_npu.npu_swiglu(x)
|
||||
return out
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,78 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
_original_re_init = RMSNorm.__init__
|
||||
|
||||
|
||||
def torchair_rmsnorm_init_(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
_original_re_init(self, hidden_size, eps, var_hidden_size, has_weight,
|
||||
dtype)
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.bias = None
|
||||
# quantization with anti_method m4 will generate none-zero norm bias
|
||||
if vllm_config.quant_config is not None and \
|
||||
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
|
||||
def torchair_rmsnorm_forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""AscendRMSNorm forward in torchair mode.
|
||||
|
||||
The key difference from the original implementation is the removal of operators
|
||||
from the torch.ops.vllm class, as these operators only function in non-torchair
|
||||
modes. Adding them back would cause the graph compilation to fail.
|
||||
"""
|
||||
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
if residual is not None:
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
x, _ = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
if self.bias is not None:
|
||||
x.add_(self.bias)
|
||||
return x, residual
|
||||
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
|
||||
if self.bias is not None:
|
||||
x.add_(self.bias)
|
||||
return x
|
||||
@@ -1,367 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
|
||||
get_ascend_device_type)
|
||||
|
||||
|
||||
def custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def rope_forward_oot(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
is_qwen_torchair: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if get_ascend_config(
|
||||
).torchair_graph_config.enabled and not is_qwen_torchair:
|
||||
return self.forward_native(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets,
|
||||
)
|
||||
|
||||
query_shape, key_shape = query.shape, key.shape
|
||||
if self.cos_sin_cache.device != query.device:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
neox_style = self.is_neox_style
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if custom_rotary_embedding_enabled(
|
||||
query, neox_style, self.head_size) and get_ascend_device_type(
|
||||
) != AscendDeviceType._310P:
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
if offsets is not None:
|
||||
raise NotImplementedError(
|
||||
"Batched rotary embedding is currently not supported on NPU.")
|
||||
else:
|
||||
# TODO: Remove the contiguous in the future.
|
||||
query = query.contiguous().view(query.shape[0], -1)
|
||||
key = key.contiguous().view(key.shape[0], -1)
|
||||
torch_npu._npu_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
neox_style,
|
||||
)
|
||||
return query.view(query_shape), key.view(key_shape)
|
||||
|
||||
|
||||
def native_rope_deepseek_forward(self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None):
|
||||
if len(key.shape) == 2:
|
||||
key = key[:, None, :]
|
||||
# Note: we implement the non neox_style method with shuffle the last dim and neox style
|
||||
# calculation method which is also more compute friendly to the ascend machine
|
||||
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
|
||||
neox_style = True
|
||||
if self.is_neox_style is False:
|
||||
b, h_q, d = query.shape
|
||||
query = query.view(b, h_q, d // 2, 2).transpose(3,
|
||||
2).reshape(b, h_q, d)
|
||||
b, h_k, d = key.shape
|
||||
key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d)
|
||||
q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets,
|
||||
neox_style)
|
||||
return q_pe, k_pe
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
def yarn_find_correction_dim(num_rotations,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
return (dim * torch.log(
|
||||
torch.tensor(max_position_embeddings) /
|
||||
(num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def yarn_find_correction_range(low_rot,
|
||||
high_rot,
|
||||
dim,
|
||||
base=10000,
|
||||
max_position_embeddings=2048):
|
||||
# Note: use torch instead of math to solve MTP compilation error.
|
||||
low = torch.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = torch.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
# Note: use torch instead of max/min to solve MTP compilation error.
|
||||
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_value, max_value, dim):
|
||||
# Note: The if conditional branch is not used here
|
||||
# to solve MTP compilation error.
|
||||
max_value += (min_value == max_value).float() * 0.001
|
||||
linear_func = (torch.arange(dim, dtype=torch.float32) -
|
||||
min_value) / (max_value - min_value)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids]
|
||||
sin = sin[position_ids]
|
||||
cos = cos[:, None, None, :]
|
||||
sin = sin[:, None, None, :]
|
||||
|
||||
if len(q.shape) == 3:
|
||||
q = q[:, :, None, :]
|
||||
if len(k.shape) == 2:
|
||||
k = k[:, None, None, :]
|
||||
elif len(k.shape) == 3:
|
||||
k = k[:, :, None, :]
|
||||
|
||||
b, h_q, s, d = q.shape
|
||||
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
|
||||
|
||||
b, h_k, s, d = k.shape
|
||||
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
q_embed = q_embed.view(b, h_q, d)
|
||||
k_embed = k_embed.view(b, h_k, d)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
|
||||
dim = self.rotary_dim
|
||||
|
||||
freq_extra = 1.0 / (self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
freq_inter = 1.0 / (self.scaling_factor * self.base**(
|
||||
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
|
||||
device=device, dtype=torch.float32)
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
|
||||
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
|
||||
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
|
||||
cos_cached = cos_cached.to(dtype)
|
||||
sin_cached = sin_cached.to(dtype)
|
||||
cache = torch.cat([freqs.cos() * self.mscale,
|
||||
freqs.sin() * self.mscale],
|
||||
dim=-1).to(dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
|
||||
|
||||
def __set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
inv_freq = 1.0 / (self.base**(torch.arange(
|
||||
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
|
||||
(1 / self.rotary_dim)))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
t = torch.arange(self.max_position_embeddings,
|
||||
device=self.inv_freq.device,
|
||||
dtype=torch.float32)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
|
||||
self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
|
||||
self.embed = F.embedding
|
||||
|
||||
|
||||
_original_re_init = RotaryEmbedding.__init__
|
||||
|
||||
|
||||
def qwen_rope_init_func(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
_original_re_init(self, head_size, rotary_dim, max_position_embeddings,
|
||||
base, is_neox_style, dtype)
|
||||
if get_ascend_config().torchair_graph_config.enabled:
|
||||
__set_cos_sin_cache(self,
|
||||
seq_len=max_position_embeddings,
|
||||
device="npu",
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
def rope_forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
is_neox_style_override: Optional[bool] = None,
|
||||
max_seq_len: Optional[int] = None,
|
||||
is_prefill: Optional[bool] = True,
|
||||
is_qwen_torchair: Optional[bool] = False,
|
||||
):
|
||||
if get_ascend_config().torchair_graph_config.enabled \
|
||||
and is_qwen_torchair and not is_prefill:
|
||||
if max_seq_len is not None and torch.gt(max_seq_len,
|
||||
self.max_position_embeddings):
|
||||
__set_cos_sin_cache(self,
|
||||
seq_len=max_seq_len,
|
||||
device=query.device,
|
||||
dtype=torch.float32)
|
||||
|
||||
# bsnd/bnsd
|
||||
if positions is not None:
|
||||
cos = self.embed(positions, self.cos)
|
||||
sin = self.embed(positions, self.sin)
|
||||
self.cos_embed = cos
|
||||
self.sin_embed = sin
|
||||
else:
|
||||
cos = self.cos_embed
|
||||
sin = self.sin_embed
|
||||
|
||||
query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
|
||||
key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
|
||||
|
||||
cos = cos.unsqueeze(-2).unsqueeze(-2)
|
||||
sin = sin.unsqueeze(-2).unsqueeze(-2)
|
||||
|
||||
query = query.unsqueeze(1)
|
||||
key = key.unsqueeze(1)
|
||||
|
||||
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
|
||||
query, key, cos, sin)
|
||||
return q_embed.flatten(-2), k_embed.flatten(-2)
|
||||
else:
|
||||
return rope_forward_oot(self, positions, query, key, offsets,
|
||||
is_neox_style_override,
|
||||
is_qwen_torchair) # type: ignore
|
||||
|
||||
|
||||
def deepseek_rope_init_func(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
beta_fast: int = 32,
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
# Get n-d magnitude scaling corrected for interpolation.
|
||||
self.mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale)) /
|
||||
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
|
||||
attn_factor)
|
||||
super(DeepseekScalingRotaryEmbedding,
|
||||
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style, dtype)
|
||||
|
||||
# NOTE: For ascend friendly computing, reorder sin and cos cache
|
||||
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
|
||||
_set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu")
|
||||
@@ -1,38 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
|
||||
|
||||
def vocab_embedding_forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
masked_input, input_mask = self._get_masked_input_and_mask(
|
||||
input_, self.shard_indices.org_vocab_start_index,
|
||||
self.shard_indices.org_vocab_end_index,
|
||||
self.shard_indices.num_org_vocab_padding,
|
||||
self.shard_indices.added_vocab_start_index,
|
||||
self.shard_indices.added_vocab_end_index)
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
|
||||
@@ -1,501 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
|
||||
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
|
||||
|
||||
class TorchairAscendW4A8DynamicLinearMethod:
|
||||
"""Linear method for Ascend W4A8_DYNAMIC
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
self.new_quant_version = quant_version == "1.0.0"
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def get_weight(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
|
||||
if self.new_quant_version:
|
||||
pack_factor = 2
|
||||
actual_output_size = output_size // pack_factor
|
||||
params_dict["weight"] = torch.empty(actual_output_size,
|
||||
input_size,
|
||||
dtype=torch.int8)
|
||||
params_dict["_packed_dim"] = 0
|
||||
params_dict["_packed_factor"] = pack_factor
|
||||
else:
|
||||
params_dict["weight"] = torch.empty(output_size,
|
||||
input_size,
|
||||
dtype=torch.int8)
|
||||
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_scale_second"] = torch.empty(output_size,
|
||||
input_size //
|
||||
self.group_size,
|
||||
dtype=params_dtype)
|
||||
params_dict["weight_offset_second"] = torch.empty(output_size,
|
||||
input_size //
|
||||
self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
if self.new_quant_version:
|
||||
scale_bias_dim = 16 if layer_type == "row" else 1
|
||||
params_dict["scale_bias"] = torch.empty(output_size,
|
||||
scale_bias_dim,
|
||||
dtype=torch.float32)
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def process_scale_second(weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
per_group_scale: torch.Tensor,
|
||||
is_new_quant: bool = False):
|
||||
k, n = weight.shape
|
||||
group_num, n_scale = per_group_scale.shape
|
||||
|
||||
if is_new_quant:
|
||||
n = n * 2
|
||||
|
||||
bias = None
|
||||
if not is_new_quant:
|
||||
weight_high = weight.to(torch.float32).reshape(
|
||||
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
|
||||
weight_high = weight_high.reshape(k, n)
|
||||
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
|
||||
|
||||
antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
|
||||
return antiquant_scale.npu(), bias
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch_npu.npu_weight_quant_batchmatmul(
|
||||
x,
|
||||
layer.weight,
|
||||
antiquant_scale=layer.weight_scale_second.to(x.dtype),
|
||||
antiquant_group_size=self.group_size,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
|
||||
torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
|
||||
layer.weight.data,
|
||||
layer.weight_scale.data,
|
||||
layer.weight_scale_second.data.transpose(0, 1).contiguous(),
|
||||
is_new_quant=self.new_quant_version,
|
||||
)
|
||||
|
||||
if self.new_quant_version:
|
||||
if hasattr(layer, "scale_bias"):
|
||||
if layer.scale_bias.data.shape[1] == 1:
|
||||
layer.scale_bias.data = layer.scale_bias.data.flatten()
|
||||
else:
|
||||
layer.scale_bias.data = layer.scale_bias.data.contiguous()
|
||||
else:
|
||||
if scale_bias is not None:
|
||||
param = torch.nn.Parameter(scale_bias, requires_grad=False)
|
||||
layer.register_parameter("weight_scale_bias", param)
|
||||
|
||||
if self.new_quant_version:
|
||||
assert layer.weight.data.shape[-1] % 4 == 0, \
|
||||
f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
|
||||
layer.weight.data = layer.weight.data.view(
|
||||
torch.int32).contiguous()
|
||||
else:
|
||||
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
|
||||
layer.weight.data.to(torch.int32))
|
||||
|
||||
|
||||
class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W4A8_DYNAMIC.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
|
||||
self.is_per_channel_weight = self.group_size == 0
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
# NOTE: new quantize weights: 2 int4 pack into int8
|
||||
self.new_quant_version = quant_version == "1.0.0"
|
||||
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
|
||||
if self.new_quant_version and self.tp_size > 16:
|
||||
raise ValueError(
|
||||
"The current weight does not support moe part tp>16.")
|
||||
|
||||
try:
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
||||
local_rank)
|
||||
except AttributeError:
|
||||
self.moe_all_to_all_group_name = ""
|
||||
|
||||
def get_weight(self, num_experts: int,
|
||||
intermediate_size_per_partition: int, hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
if self.new_quant_version:
|
||||
w13_output_size = intermediate_size_per_partition
|
||||
w2_output_size = hidden_sizes // 2
|
||||
else:
|
||||
w13_output_size = 2 * intermediate_size_per_partition
|
||||
w2_output_size = hidden_sizes
|
||||
|
||||
param_dict["w13_weight"] = torch.empty(num_experts,
|
||||
w13_output_size,
|
||||
hidden_sizes,
|
||||
dtype=torch.int8)
|
||||
param_dict["w2_weight"] = torch.empty(num_experts,
|
||||
w2_output_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8)
|
||||
return param_dict
|
||||
|
||||
def get_dynamic_quant_param(self, num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
|
||||
if not self.is_per_channel_weight:
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
if self.new_quant_version:
|
||||
param_dict["w13_scale_bias"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_scale_bias"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
16 // self.tp_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk currently is 8
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group, # fix: 4
|
||||
group_count=num_expert_group, # fix 8
|
||||
group_select_mode=
|
||||
1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# out_flag=False, # todo new api; should the third output be output
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
else:
|
||||
topk_weights, topk_ids = torchair_select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
shared_gate_up, shared_dequant_scale = None, None
|
||||
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(quantized_x_for_share, router_logits)
|
||||
share_up_out, _ = shared_experts.gate_up_proj(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||
0], share_up_out[1]
|
||||
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return torchair_fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
is_torchair=self.torchair_graph_enabled,
|
||||
quantized_x_for_share=shared_gate_up,
|
||||
dynamic_scale_for_share=shared_dequant_scale,
|
||||
mc2_mask=kwargs.get("mc2_mask", None),
|
||||
dynamic_eplb=self.dynamic_eplb)
|
||||
else:
|
||||
# The current implementation of deepseek moe splits hidden_states
|
||||
# according to tp_size before they are feed into layers module.
|
||||
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
||||
# dispatch/combine tokens.
|
||||
return torchair_fused_experts_with_all2all(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
)
|
||||
|
||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||
scale = scale.transpose(1, 2).contiguous()
|
||||
if self.is_per_channel_weight:
|
||||
scale_np = scale.cpu().numpy()
|
||||
scale_np.dtype = np.uint32
|
||||
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
|
||||
np.int64)).npu()
|
||||
return scale_uint64_tensor, None
|
||||
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
|
||||
group_num, k, n = weight.shape
|
||||
# the weight of the new version is reduced by half by pack n, so it needs to be restored
|
||||
if self.new_quant_version:
|
||||
n = n * 2
|
||||
per_group_scale = per_group_scale.reshape(group_num, -1, n)
|
||||
group_num, quantgroup_num, n = per_group_scale.shape
|
||||
bias = None
|
||||
if not self.new_quant_version:
|
||||
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
|
||||
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
|
||||
weight_high = weight_high.reshape([group_num, k, n])
|
||||
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
|
||||
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
|
||||
torch.float32)
|
||||
scale_fp32_np = scale_fp32.cpu().numpy()
|
||||
scale_fp32_np.dtype = np.uint32
|
||||
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
|
||||
dtype=np.uint32)
|
||||
|
||||
sscale_uint64[..., ::2] = scale_fp32_np
|
||||
|
||||
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
|
||||
dtype=np.int64).copy()
|
||||
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
|
||||
group_num, quantgroup_num, n)
|
||||
sscale_uint64_tensor = sscale_uint64_tensor.npu()
|
||||
return sscale_uint64_tensor, bias
|
||||
|
||||
def update_bias(self, layer, w13_bias, w2_bias):
|
||||
if self.new_quant_version:
|
||||
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
|
||||
1, 2).contiguous().sum(axis=1)
|
||||
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
|
||||
1, 2).contiguous().sum(axis=1)
|
||||
else:
|
||||
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
|
||||
layer.register_parameter("w13_scale_bias", w13_scale_bias)
|
||||
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
|
||||
layer.register_parameter("w2_scale_bias", w2_scale_bias)
|
||||
|
||||
def pack_to_int32(self, weight: torch.Tensor):
|
||||
if self.new_quant_version:
|
||||
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
|
||||
assert weight.shape[
|
||||
-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
|
||||
return weight.view(torch.int32).contiguous()
|
||||
else:
|
||||
return torch_npu.npu_quantize(weight.to(torch.float32),
|
||||
torch.tensor([1.]).npu(), None,
|
||||
torch.quint4x2, -1, False)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
|
||||
layer, "w13_weight_scale_second") else None
|
||||
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
|
||||
layer, "w2_weight_scale_second") else None
|
||||
layer.w13_weight_scale.data, w13_bias = self.process_scale(
|
||||
layer.w13_weight, layer.w13_weight_scale.data,
|
||||
w13_weight_scale_second)
|
||||
layer.w2_weight_scale.data, w2_bias = self.process_scale(
|
||||
layer.w2_weight, layer.w2_weight_scale.data,
|
||||
w2_weight_scale_second)
|
||||
if hasattr(layer, "w13_weight_scale_second"):
|
||||
# scale_second is no longer used, release this part of the memory
|
||||
del layer.w13_weight_scale_second
|
||||
del layer.w2_weight_scale_second
|
||||
del layer.w13_weight_offset_second
|
||||
del layer.w2_weight_offset_second
|
||||
|
||||
self.update_bias(layer, w13_bias, w2_bias)
|
||||
|
||||
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,457 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
|
||||
AttentionType)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
||||
AscendAttentionMetadataBuilder,
|
||||
AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
aligned_16, get_ascend_device_type, nd_to_nz_2d)
|
||||
|
||||
|
||||
class AscendAttentionTorchairBackend(AscendAttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ASCEND_TORCHAIR"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
|
||||
return AscendAttentionTorchairBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]:
|
||||
return AscendAttentionTorchairMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_bsh_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendDecodeMetadata:
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
max_seq_lens: int
|
||||
seq_lens_list: list[int]
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendTorchairMetadata(AscendMetadata):
|
||||
|
||||
decode: Optional[AscendDecodeMetadata] = None
|
||||
|
||||
|
||||
class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
self.max_num_blocks_per_req = cdiv(
|
||||
self.model_config.max_model_len,
|
||||
self.vllm_config.cache_config.block_size)
|
||||
self.max_blocks = (self.model_config.max_model_len +
|
||||
self.vllm_config.cache_config.block_size -
|
||||
1) // self.vllm_config.cache_config.block_size
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
||||
max_blocks = self.max_blocks
|
||||
|
||||
graph_block_tables = torch.zeros((num_seqs, max_blocks),
|
||||
dtype=block_tables.dtype,
|
||||
device=block_tables.device)
|
||||
|
||||
num_blocks = block_tables.size(1)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[:num_seqs, :
|
||||
num_blocks] = block_tables[:num_seqs, :
|
||||
num_blocks]
|
||||
else:
|
||||
graph_block_tables[:num_seqs, :
|
||||
max_blocks] = block_tables[:num_seqs, :
|
||||
max_blocks]
|
||||
|
||||
return graph_block_tables[:, :max_blocks]
|
||||
|
||||
def build_torchair_graph_dummy(
|
||||
self, common_attn_metadata: TorchairCommonAttentionMetadata
|
||||
) -> AscendTorchairMetadata:
|
||||
device = self.device
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
block_table = torch.zeros((num_reqs, self.max_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_reqs, block_table)
|
||||
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
|
||||
input_positions = torch.zeros(num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=device).long()
|
||||
slot_mapping = torch.full((num_reqs, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
query_start_loc = torch.full((num_reqs, ),
|
||||
-1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
decode_metadata = AscendDecodeMetadata(input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=1)
|
||||
|
||||
attn_metadata = AscendTorchairMetadata(
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_lens=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
decode=decode_metadata)
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
block_table[:num_reqs, :self.max_num_blocks_per_req] = (
|
||||
block_table[:num_reqs])
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._310P and attn_state == AscendAttentionState.PrefillNoCache:
|
||||
mask_nz = nd_to_nz_2d(attn_mask)
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
+ 1]
|
||||
query_start_loc = query_start_loc_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
graph_pad_size = common_attn_metadata.graph_pad_size
|
||||
use_torchair_graph = graph_pad_size > -1
|
||||
if common_attn_metadata.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
]:
|
||||
max_seq_lens = seq_lens.max().item()
|
||||
num_seqs = len(seq_lens)
|
||||
if use_torchair_graph and common_attn_metadata.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
]:
|
||||
num_reqs_pad_size = 0
|
||||
num_token_pad_size = 0
|
||||
if graph_pad_size != 0:
|
||||
pad_value = 0
|
||||
num_token_pad_size = graph_pad_size - num_actual_tokens
|
||||
num_reqs_pad_size = (
|
||||
graph_pad_size //
|
||||
common_attn_metadata.decode_token_per_req - num_reqs)
|
||||
pad_value = 1
|
||||
padded_seq_lens = seq_lens.tolist() + [pad_value
|
||||
] * num_reqs_pad_size
|
||||
|
||||
seq_lens = torch.from_numpy(
|
||||
np.array(padded_seq_lens).astype(np.int32))
|
||||
padding = torch.full((num_token_pad_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=slot_mapping.dtype,
|
||||
device=slot_mapping.device)
|
||||
slot_mapping = torch.cat([slot_mapping, padding])
|
||||
block_table_padding = torch.zeros(
|
||||
(num_reqs_pad_size, ) + block_table.shape[1:],
|
||||
dtype=block_table.dtype,
|
||||
device=block_table.device)
|
||||
block_table = torch.cat([block_table, block_table_padding],
|
||||
dim=0)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_seqs + num_reqs_pad_size, block_table)
|
||||
padding_0 = torch.zeros(num_token_pad_size,
|
||||
dtype=input_positions.dtype,
|
||||
device=input_positions.device)
|
||||
input_positions = torch.cat([input_positions, padding_0])
|
||||
|
||||
decode_metadata = AscendDecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=None)
|
||||
|
||||
attn_metadata = AscendTorchairMetadata(
|
||||
decode=decode_metadata,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AscendAttentionTorchairBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.hidden_size = self.num_heads * self.head_size
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.sliding_window = sliding_window
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes,
|
||||
dtype=torch.float32,
|
||||
device="npu")
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.attn_type = attn_type
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
self.scale_tensor = torch.zeros((), device='npu', dtype=torch.int32)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AscendTorchairMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Ascend attention.
|
||||
Args:
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
kv_cache: shape = [2, num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
key_cache = [num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
value_cache = [num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [batch_size * seq_len, num_heads, head_size]
|
||||
"""
|
||||
num_tokens = query.shape[0]
|
||||
use_kv_cache_quant = (kv_cache is not None and len(kv_cache) > 0
|
||||
and kv_cache[0].numel() > 0
|
||||
and kv_cache[0].dtype == torch.int8)
|
||||
if output is None:
|
||||
output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
|
||||
if hasattr(layer, 'quant_method') and use_kv_cache_quant:
|
||||
output = layer.quant_method.apply(layer, query, key, value,
|
||||
kv_cache, attn_metadata,
|
||||
self.attn_type, self.scale,
|
||||
output)
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
if attn_metadata is None:
|
||||
return output.view(num_tokens, self.hidden_size).fill_(0)
|
||||
|
||||
output = output.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
attn_type = self.attn_type
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"AscendAttentionTorchairBackendImpl")
|
||||
|
||||
if kv_cache is not None and kv_cache[0].numel() > 0:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
|
||||
block_size = self.scale_tensor + key_cache.shape[1]
|
||||
slots_indices = slots.reshape(-1, 1)
|
||||
block_indices = slots_indices // block_size
|
||||
slots_indices = slots_indices % block_size
|
||||
indices = torch.cat((block_indices, slots_indices), dim=1)
|
||||
torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
|
||||
torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
self.key_cache = key_cache
|
||||
self.value_cache = value_cache
|
||||
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
|
||||
# View q k v to BSH.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# align q k v output tensors
|
||||
query = aligned_16(query)
|
||||
key = aligned_16(key)
|
||||
value = aligned_16(value)
|
||||
output = aligned_16(output)
|
||||
|
||||
# do reformat in case of broadcasted tensors
|
||||
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
|
||||
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output)
|
||||
output = output[:num_tokens, :, :]
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
compress_mask = attn_metadata.attn_mask
|
||||
batch_size = attn_metadata.query_lens.shape[0]
|
||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||
torch_npu._npu_flash_attention_qlens(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
block_table=block_table,
|
||||
mask=compress_mask,
|
||||
seq_len=attn_metadata.query_lens,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
out=output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
decode_meta = attn_metadata.decode
|
||||
assert decode_meta is not None
|
||||
seq_lens = decode_meta.seq_lens_list
|
||||
block_table = decode_meta.block_table
|
||||
block_size = key_cache.shape[1]
|
||||
query = query.view(num_tokens, 1,
|
||||
self.num_heads * self.head_size).contiguous()
|
||||
output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
query=query,
|
||||
key=key_cache,
|
||||
value=value_cache,
|
||||
query_rope=None,
|
||||
key_rope=None,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout='BSH',
|
||||
atten_mask=decode_meta.attn_mask,
|
||||
sparse_mode=0,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=seq_lens,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Torchair graph mode with non-MLA attention backend is still experimental."
|
||||
"v1 scheduler(chunked prefill) is not supported at this moment."
|
||||
)
|
||||
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,574 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||
# isort: skip_file
|
||||
|
||||
import math
|
||||
import types
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.spec_decode import get_spec_decode_method
|
||||
from vllm_ascend.torchair.utils import (
|
||||
TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata,
|
||||
check_torchair_cache_exist, converting_weight_acl_format,
|
||||
register_torchair_model, torchair_ops_patch,
|
||||
torchair_quant_method_register, write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
AscendDeviceType, get_ascend_device_type)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
self.ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
|
||||
super().__init__(vllm_config, device)
|
||||
if self.speculative_config:
|
||||
self.actual_seq_lengths_q = list(
|
||||
range(self.decode_token_per_req, self.max_num_tokens + 1,
|
||||
self.decode_token_per_req))
|
||||
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
||||
None, None, vllm_config, device)
|
||||
self.use_sparse = hasattr(self.model_config.hf_config, "index_topk")
|
||||
|
||||
register_torchair_model()
|
||||
torchair_ops_patch()
|
||||
torchair_quant_method_register()
|
||||
if self.enable_shared_expert_dp:
|
||||
return
|
||||
self.new_kv_cache_bytes = -1
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph
|
||||
self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
|
||||
self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes
|
||||
if self.ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
|
||||
self.update_torchair_graph_batch_sizes()
|
||||
|
||||
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||
self.torchair_graph_batch_sizes)
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
torch._logging.set_logs(
|
||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||
|
||||
self._check_batch_sizes_consistency()
|
||||
|
||||
def _set_up_drafter(self):
|
||||
super()._set_up_drafter()
|
||||
if self.speculative_config:
|
||||
# Torchair do not support disable_padded_drafter_batch
|
||||
# Enforce to disable this feature
|
||||
self.speculative_config.disable_padded_drafter_batch = True
|
||||
|
||||
def _get_drafter(self):
|
||||
return get_spec_decode_method(self.speculative_config.method,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
self,
|
||||
is_torchair_graph=True)
|
||||
|
||||
def _may_pad_kv_consumer_num_seq(self):
|
||||
# pd disaggregation scenario need redundant_batch_sizes to avoid each batch's seq_len exceed 16 tokens
|
||||
# self.max_num_reqs here is greater than the actual maximum request number
|
||||
if self.decode_token_per_req > 1 and self.is_kv_consumer:
|
||||
# applied only when speculative decoding is active
|
||||
FIA_SEQ_LEN_LIMIT = 16
|
||||
new_max_num_reqs = self.max_num_reqs + math.ceil(
|
||||
self.max_num_reqs / FIA_SEQ_LEN_LIMIT) + math.ceil(
|
||||
(self.max_num_reqs * self.decode_token_per_req) /
|
||||
(FIA_SEQ_LEN_LIMIT**2))
|
||||
if self.max_num_reqs < new_max_num_reqs:
|
||||
logger.warning(
|
||||
f"max_num_reqs is updated to {new_max_num_reqs}")
|
||||
self.max_num_reqs = new_max_num_reqs
|
||||
|
||||
def _init_mc2_tokens_capacity(self):
|
||||
# NOTE: To be clear, we need to make sure that during graph capture, the number of
|
||||
# tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
|
||||
# the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
|
||||
max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
# Use integer arithmetic for ceiling division.
|
||||
max_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
|
||||
max_num_tokens, tp_size)
|
||||
self.mc2_tokens_capacity = max_graph_batch_size
|
||||
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910_93 and self.mc2_tokens_capacity > 512:
|
||||
logger.error(
|
||||
f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
|
||||
)
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910B and self.mc2_tokens_capacity > 256:
|
||||
logger.error(
|
||||
f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
|
||||
)
|
||||
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int,
|
||||
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
|
||||
"""Override from NPUModelRunner to pad num_tokens"""
|
||||
if self.enable_shared_expert_dp:
|
||||
# Padding is not required for shared_expert_dp cases in eager mode.
|
||||
return num_tokens, None, with_prefill
|
||||
if self.dp_size == 1:
|
||||
if not with_prefill:
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
num_tokens)
|
||||
return maybe_padded_num_tokens, None, with_prefill
|
||||
return num_tokens, None, with_prefill
|
||||
|
||||
num_tokens_across_dp = torch.zeros(self.dp_size + 1,
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
num_tokens_across_dp[self.dp_rank] = num_tokens
|
||||
num_tokens_across_dp[-1] = int(with_prefill)
|
||||
dist.all_reduce(num_tokens_across_dp,
|
||||
group=get_dp_group().device_group)
|
||||
with_prefill = bool(num_tokens_across_dp[-1])
|
||||
num_tokens_across_dp = num_tokens_across_dp[:-1]
|
||||
|
||||
if not with_prefill:
|
||||
max_num_token = num_tokens_across_dp.max().item()
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
max_num_token)
|
||||
num_tokens_across_dp = torch.full((self.dp_size, ),
|
||||
maybe_padded_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
else:
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill
|
||||
|
||||
def _build_dummy_attn_metadata(
|
||||
self,
|
||||
with_prefill: bool,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
max_query_len: int,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
||||
force_attention: bool = False,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
# NOTE: If torchair graph mode and not with_prefill,
|
||||
# we can't skip_attn, it will cause graph recompile.
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
attn_metadata = super()._build_dummy_attn_metadata(
|
||||
with_prefill, num_reqs, num_tokens, max_query_len,
|
||||
num_scheduled_tokens, aclgraph_runtime_mode, force_attention)
|
||||
else:
|
||||
common_attn_metadata = TorchairCommonAttentionMetadata(
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=1,
|
||||
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
||||
attn_mask=self.attn_mask,
|
||||
spec_attn_mask=self.spec_attn_mask,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
)
|
||||
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
|
||||
common_attn_metadata)
|
||||
return attn_metadata
|
||||
|
||||
def _generate_dummy_run_hidden_states(self, with_prefill,
|
||||
is_torchair_compile, input_ids,
|
||||
positions, attn_metadata, num_tokens,
|
||||
intermediate_tensors, inputs_embeds):
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
hidden_states = super()._generate_dummy_run_hidden_states(
|
||||
with_prefill, is_torchair_compile, input_ids, positions,
|
||||
attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
|
||||
else:
|
||||
# Only mark static while compiling
|
||||
if is_torchair_compile:
|
||||
torch._dynamo.mark_static(input_ids)
|
||||
torch._dynamo.mark_static(positions)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
|
||||
torch._dynamo.mark_static(get_forward_context().mc2_mask)
|
||||
if hasattr(attn_metadata.decode, "sin"):
|
||||
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
if self.speculative_config:
|
||||
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
|
||||
for kv in self.kv_caches:
|
||||
assert isinstance(kv, tuple), "kv_cache must be a tuple"
|
||||
torch._dynamo.mark_static(kv[0])
|
||||
torch._dynamo.mark_static(kv[1])
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
|
||||
model_kwargs = {}
|
||||
model_kwargs["kv_caches"] = self.kv_caches
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=None,
|
||||
**model_kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def _convert_torch_format(self, kv_cache):
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._convert_torch_format(kv_cache)
|
||||
kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
|
||||
return kv_cache
|
||||
|
||||
def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
|
||||
# Trigger torchair graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
||||
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
|
||||
|
||||
def _capture_model(self):
|
||||
"""Override from NPUModelRunner to use torchair graph capture."""
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._capture_model()
|
||||
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
|
||||
# torchair graph capture can cause some issues, so now we just
|
||||
# temporarily split the codepath for the two different graph patterns.
|
||||
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
|
||||
graph_num = len(torchair_graph_batch_sizes)
|
||||
|
||||
if self.use_cached_npu_graph and not check_torchair_cache_exist():
|
||||
# If caching is enabled but does not exist (either
|
||||
# use_cached_kv_cache_bytes is disabled or kv_cache_bytes are
|
||||
# different), we will compile the model twice. The first time is
|
||||
# used to generate the cache, and the second time is used to load the
|
||||
# cache to skip the overhead caused by Dynamo guard mechanism.
|
||||
logger.info(
|
||||
"Cache compilation for torchair graph is enabled. Now we compile graph to genetate"
|
||||
" torchair cache, this usually takes %.1f~%.1f mins.",
|
||||
0.5 * graph_num, 1.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
NPUPlatform.synchronize()
|
||||
# Note: We reset dynamo and reload the compiled torchair cached computation graph below
|
||||
# that was compiled above. This operation reduces graph launch time by 2-4ms and avoids
|
||||
# runtime errors caused by configuration mismatches in graph mode.
|
||||
torch._dynamo.reset()
|
||||
self.torchair_compiled_models.clear()
|
||||
if self.use_cached_npu_graph:
|
||||
logger.info(
|
||||
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
|
||||
0.3 * graph_num, 0.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
else:
|
||||
logger.info(
|
||||
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
|
||||
0.5 * graph_num, 1.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
|
||||
if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0:
|
||||
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
|
||||
self.new_kv_cache_bytes)
|
||||
|
||||
def _use_aclgraph(self) -> bool:
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._use_aclgraph()
|
||||
return False
|
||||
|
||||
def _check_batch_sizes_consistency(self) -> None:
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
|
||||
local = torch.tensor(self.torchair_graph_batch_sizes,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
gathered_graph_batch_size = local.clone()
|
||||
dist.all_reduce(gathered_graph_batch_size,
|
||||
group=get_dp_group().cpu_group)
|
||||
expected = local * self.dp_size
|
||||
|
||||
if not torch.equal(gathered_graph_batch_size, expected):
|
||||
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
|
||||
as_tuple=False).flatten().tolist()
|
||||
raise AssertionError(
|
||||
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
|
||||
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
|
||||
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
|
||||
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
|
||||
)
|
||||
|
||||
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
super()._update_graph_pad_size(with_prefill, graph_pad_size)
|
||||
else:
|
||||
self.graph_pad_size = graph_pad_size
|
||||
|
||||
def _update_input_ids_and_positions(self, input_ids, positions,
|
||||
num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp):
|
||||
"""Override from NPUModelRunner to update input_ids and positions"""
|
||||
input_ids, positions = super()._update_input_ids_and_positions(
|
||||
input_ids, positions, num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp)
|
||||
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
return input_ids, positions
|
||||
else:
|
||||
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
||||
positions = self.positions[:padded_num_tokens_across_dp]
|
||||
return input_ids, positions
|
||||
|
||||
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||
padded_num_tokens_across_dp,
|
||||
input_ids, positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds):
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
|
||||
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._generate_process_reqs_hidden_states(
|
||||
attn_metadata, with_prefill, padded_num_tokens_across_dp,
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds)
|
||||
model_kwargs = {
|
||||
"kv_caches": self.kv_caches,
|
||||
"attn_metadata": attn_metadata
|
||||
}
|
||||
if not with_prefill:
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_num_tokens_across_dp)
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
assert self.model is not None
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
|
||||
raise ValueError(
|
||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
|
||||
)
|
||||
|
||||
compiled_model = self.torchair_compiled_models.get(
|
||||
batch_size
|
||||
) if self.use_cached_npu_graph else self.torchair_compiled_model
|
||||
|
||||
if compiled_model:
|
||||
return compiled_model
|
||||
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
patch_for_hcom()
|
||||
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
||||
from vllm_ascend.patch.platform.patch_distributed import \
|
||||
communication_adaptation_310p
|
||||
communication_adaptation_310p()
|
||||
|
||||
config = torchair.CompilerConfig()
|
||||
if self.ascend_config.torchair_graph_config.mode:
|
||||
config.mode = self.ascend_config.torchair_graph_config.mode
|
||||
config.experimental_config.frozen_parameter = \
|
||||
self.ascend_config.torchair_graph_config.enable_frozen_parameter
|
||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||
# disable it on 300I Duo platform now.
|
||||
config.experimental_config.tiling_schedule_optimize = get_ascend_device_type(
|
||||
) != AscendDeviceType._310P
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
self.ascend_config.torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
if not self.use_cached_npu_graph:
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
self.torchair_compiled_model = torch.compile(
|
||||
self.model,
|
||||
dynamic=not self.use_sparse,
|
||||
fullgraph=True,
|
||||
backend=npu_backend)
|
||||
return self.torchair_compiled_model
|
||||
else:
|
||||
# Generate a new forward proxy code object to prevent the invalidation of
|
||||
# compilation cache caused by dynamo retracing
|
||||
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
||||
forward_fn = self.model.forward
|
||||
code = forward_fn.__code__
|
||||
# Mark code object with a new proxy name
|
||||
modified_code = code.replace(co_name=forward_proxy_name, )
|
||||
|
||||
modified_func = types.FunctionType(modified_code,
|
||||
forward_fn.__globals__,
|
||||
name=forward_proxy_name,
|
||||
argdefs=forward_fn.__defaults__)
|
||||
|
||||
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
||||
self.model, nn.Module)
|
||||
self.torchair_compiled_models[
|
||||
batch_size] = torchair.inference.cache_compile(
|
||||
self.model.__dict__[forward_proxy_name],
|
||||
dynamic=not self.use_sparse,
|
||||
fullgraph=True,
|
||||
cache_dir=TORCHAIR_CACHE_DIR,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
return self.torchair_compiled_models[batch_size]
|
||||
|
||||
def init_torchair_graph_batch_sizes(self):
|
||||
start_graph_batch_size = 4
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
|
||||
start_graph_batch_size = max(start_graph_batch_size, tp_size)
|
||||
|
||||
while (start_graph_batch_size <= self.max_num_reqs):
|
||||
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
|
||||
start_graph_batch_size *= 2
|
||||
|
||||
def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size,
|
||||
tp_size):
|
||||
cur_graph_batch_size = (old_graph_batch_size + tp_size -
|
||||
1) // tp_size * tp_size
|
||||
# MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size,
|
||||
# Both adapter multi-dp and FIA operator
|
||||
if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1:
|
||||
cur_graph_batch_size = (tp_size * old_graph_batch_size) \
|
||||
// math.gcd(tp_size, old_graph_batch_size)
|
||||
return cur_graph_batch_size
|
||||
|
||||
def select_torchair_padded_batch_size(self, batch_size: int):
|
||||
for padded_batch_size in self.torchair_graph_batch_sizes:
|
||||
if batch_size <= padded_batch_size:
|
||||
# we treat batch_size as num of requests
|
||||
return padded_batch_size
|
||||
raise ValueError(
|
||||
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
|
||||
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
|
||||
)
|
||||
|
||||
def update_torchair_graph_batch_sizes(self):
|
||||
# return graph_batch_sizes according to the max number of tokens
|
||||
# first pad according to the number of requests
|
||||
if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'mtp':
|
||||
# pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs
|
||||
self.torchair_graph_batch_sizes = [self.max_num_reqs]
|
||||
logger.warning(
|
||||
f"is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs] {[self.max_num_reqs]}"
|
||||
)
|
||||
elif len(self.torchair_graph_batch_sizes) == 0:
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
else:
|
||||
self.torchair_graph_batch_sizes = sorted(
|
||||
self.torchair_graph_batch_sizes)
|
||||
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.pop()
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
logger.warning(
|
||||
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
|
||||
)
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
|
||||
|
||||
# padded max number tokens = max_num_req * decode_token_per_req
|
||||
self.torchair_graph_batch_sizes = [
|
||||
graph_batch_size * self.decode_token_per_req
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes
|
||||
]
|
||||
|
||||
# NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size`
|
||||
# Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same
|
||||
# on all EP ranks
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910_93 and self.parallel_config.enable_expert_parallel:
|
||||
self._align_graph_size_divisible_by_tp_size()
|
||||
|
||||
def _align_graph_size_divisible_by_tp_size(self):
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
new_graph_batch_sizes = []
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes:
|
||||
cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
|
||||
graph_batch_size, tp_size)
|
||||
if cur_graph_batch_size not in new_graph_batch_sizes and \
|
||||
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
|
||||
new_graph_batch_sizes.append(cur_graph_batch_size)
|
||||
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
|
||||
and self.decode_token_per_req > 1:
|
||||
logger.warning(
|
||||
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
|
||||
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
|
||||
)
|
||||
new_max_num_reqs = math.ceil(
|
||||
max(new_graph_batch_sizes) / self.decode_token_per_req)
|
||||
if self.max_num_reqs != new_max_num_reqs:
|
||||
logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}")
|
||||
self.max_num_reqs = new_max_num_reqs
|
||||
if not (self.decode_token_per_req > 1 and self.is_kv_consumer):
|
||||
# Do not update scheduler_config.max_num_seqs in KV consumer + MTP
|
||||
# Since FIA need extra space for padding
|
||||
# Enforce self.max_num_seqs > self.scheduler_config.max_num_seqs in KV consumer + MTP
|
||||
self.scheduler_config.max_num_seqs = new_max_num_reqs
|
||||
|
||||
if new_graph_batch_sizes != self.torchair_graph_batch_sizes:
|
||||
logger.warning(
|
||||
f"torchair_graph_batch_sizes are updated to {new_graph_batch_sizes}."
|
||||
)
|
||||
self.torchair_graph_batch_sizes = new_graph_batch_sizes
|
||||
|
||||
def _build_drafter_prepare_inputs_torchair_param(self):
|
||||
if self.enable_shared_expert_dp:
|
||||
return super()._build_drafter_prepare_inputs_torchair_param()
|
||||
else:
|
||||
return True
|
||||
@@ -1,543 +0,0 @@
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchair
|
||||
from torchair import patch_for_hcom
|
||||
from vllm.config import (CUDAGraphMode, VllmConfig,
|
||||
get_layers_from_vllm_config, set_current_vllm_config)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import \
|
||||
process_weights_after_loading
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.spec_decode import MtpProposer
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
|
||||
TorchairDeepSeekMTP
|
||||
from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
|
||||
TorchairCommonAttentionMetadata)
|
||||
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
|
||||
class TorchairMtpProposer(MtpProposer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device,
|
||||
runner,
|
||||
):
|
||||
super().__init__(vllm_config, device, runner)
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
|
||||
def load_model(self, model) -> None:
|
||||
loader = get_model_loader(self.vllm_config.load_config)
|
||||
|
||||
target_attn_layer_names = set(
|
||||
get_layers_from_vllm_config(self.vllm_config,
|
||||
AttentionLayerBase).keys())
|
||||
draft_model_config = \
|
||||
self.vllm_config.speculative_config.draft_model_config
|
||||
target_device = self.vllm_config.device_config.device
|
||||
|
||||
with set_default_torch_dtype(
|
||||
draft_model_config.dtype), set_current_vllm_config(
|
||||
self.vllm_config):
|
||||
|
||||
self.model = TorchairDeepSeekMTP(
|
||||
vllm_config=self.vllm_config).to(target_device)
|
||||
|
||||
draft_attn_layer_names = (get_layers_from_vllm_config(
|
||||
self.vllm_config, AttentionLayerBase).keys() -
|
||||
target_attn_layer_names)
|
||||
|
||||
assert len(draft_attn_layer_names) == 1
|
||||
self.attn_layer_name = list(draft_attn_layer_names)
|
||||
|
||||
self.model.load_weights(
|
||||
loader.get_all_weights(
|
||||
self.vllm_config.speculative_config.draft_model_config,
|
||||
self.model))
|
||||
process_weights_after_loading(self.model, draft_model_config,
|
||||
target_device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(self,
|
||||
num_tokens: int,
|
||||
with_prefill: bool = False,
|
||||
skip_attn: bool = False,
|
||||
num_reqs: int = 0,
|
||||
num_tokens_across_dp=None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor=None,
|
||||
dummy_compute_logits=lambda hidden_states: None) -> None:
|
||||
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
|
||||
|
||||
if not with_prefill:
|
||||
skip_attn = False
|
||||
if skip_attn:
|
||||
attn_metadata = None
|
||||
else:
|
||||
common_attn_metadata = TorchairCommonAttentionMetadata(
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=1,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
)
|
||||
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
|
||||
common_attn_metadata)
|
||||
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
positions = self.positions[:num_tokens]
|
||||
previous_hidden_states = self.hidden_states[:num_tokens]
|
||||
for _ in range(self.num_speculative_tokens):
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
with_prefill=with_prefill,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
moe_comm_type=moe_comm_type,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=0):
|
||||
if not with_prefill:
|
||||
assert attn_metadata is not None
|
||||
torch._dynamo.mark_static(input_ids)
|
||||
torch._dynamo.mark_static(positions)
|
||||
torch._dynamo.mark_static(previous_hidden_states)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
||||
torch._dynamo.mark_static(
|
||||
attn_metadata.decode.input_positions)
|
||||
if hasattr(attn_metadata.decode, "sin"):
|
||||
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
||||
torch._dynamo.mark_static(get_forward_context().mc2_mask)
|
||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
|
||||
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
num_tokens)
|
||||
torchair_compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
hidden_states=previous_hidden_states,
|
||||
inputs_embeds=None,
|
||||
intermediate_tensors=None,
|
||||
attn_metadata=attn_metadata,
|
||||
kv_caches=self.runner.kv_caches[-1:],
|
||||
spec_step_idx=0)
|
||||
else:
|
||||
self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
hidden_states=previous_hidden_states)
|
||||
dummy_compute_logits(previous_hidden_states)
|
||||
if with_prefill:
|
||||
break
|
||||
|
||||
def generate_token_ids(self,
|
||||
valid_sampled_token_ids: list[list[int]],
|
||||
sampling_metadata: SamplingMetadata = None,
|
||||
scheduler_output: SchedulerOutput = None,
|
||||
spec_decode_metadata: SpecDecodeMetadata = None,
|
||||
positions: torch.Tensor = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
hidden_states: torch.Tensor = None,
|
||||
attn_metadata=None,
|
||||
aux_hidden_states: torch.Tensor = None):
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
|
||||
next_token_ids: list[int] = []
|
||||
for i, token_ids in enumerate(valid_sampled_token_ids):
|
||||
if token_ids:
|
||||
# Common case.
|
||||
next_token_id = token_ids[-1]
|
||||
else:
|
||||
# Partial prefill (rare case).
|
||||
# Get the next token id from the request state.
|
||||
req_id = self.runner.input_batch.req_ids[i]
|
||||
req_state = self.runner.requests[req_id]
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
next_token_id = req_state.get_token_id(seq_len)
|
||||
next_token_ids.append(next_token_id)
|
||||
next_token_ids = torch.tensor(next_token_ids,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
accepted_token_indices = None
|
||||
if spec_decode_metadata is None:
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
target_slot_mapping = attn_metadata.slot_mapping
|
||||
cu_num_tokens = attn_metadata.query_start_loc
|
||||
else:
|
||||
# TODO(woosuk): Refactor this.
|
||||
num_draft_tokens = spec_decode_metadata.num_draft_tokens
|
||||
num_rejected_tokens = [
|
||||
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
|
||||
for i, n in enumerate(num_draft_tokens)
|
||||
]
|
||||
num_rejected_tokens = torch.tensor(
|
||||
num_rejected_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
cu_num_tokens, accepted_token_indices, target_token_ids, \
|
||||
target_positions, target_hidden_states, target_slot_mapping = self._torchair_prepare_inputs(
|
||||
attn_metadata.query_start_loc,
|
||||
num_rejected_tokens,
|
||||
self.runner.input_ids[:num_scheduled_tokens],
|
||||
positions[:num_scheduled_tokens],
|
||||
hidden_states[:num_scheduled_tokens],
|
||||
attn_metadata.slot_mapping[:num_scheduled_tokens],
|
||||
)
|
||||
|
||||
draft_token_ids = self._propose_torchair(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
target_slot_mapping=target_slot_mapping,
|
||||
next_token_ids=next_token_ids,
|
||||
cu_num_tokens=cu_num_tokens,
|
||||
block_table=attn_metadata.block_tables,
|
||||
sampling_metadata=sampling_metadata,
|
||||
token_indices=accepted_token_indices)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
return spec_token_ids
|
||||
|
||||
def _torchair_prepare_inputs(
|
||||
self,
|
||||
# [batch_size + 1]
|
||||
cu_target_query_lens: torch.Tensor,
|
||||
# [batch_size]
|
||||
num_rejected_tokens: torch.Tensor,
|
||||
token_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
torch.Tensor, torch.Tensor]:
|
||||
# cu_target_query_lens: [0, a, a + b, a + b + c]
|
||||
# num_rejected_tokens: [n1, n2, n3]
|
||||
# num_tokens_per_req: [a - n1, b - n2, c - n3]
|
||||
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
|
||||
# token_indices: [0, 1, ..., a - n1 - 1,
|
||||
# a, a + 1, ..., a + b - n2 - 1,
|
||||
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
|
||||
# [0, a, a + b, a + b + c] -> [a, b, c]
|
||||
query_len_per_req = (cu_target_query_lens[1:] -
|
||||
cu_target_query_lens[:-1])
|
||||
# [a, b, c] -> [a - n1, b - n2, c - n3]
|
||||
|
||||
cu_num_tokens = cu_target_query_lens
|
||||
relative_index = query_len_per_req - num_rejected_tokens - 1
|
||||
token_indices = cu_num_tokens[:-1] + relative_index
|
||||
# the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model
|
||||
target_token_ids = token_ids
|
||||
target_positions = positions
|
||||
target_hidden_states = hidden_states
|
||||
target_slot_mapping = slot_mapping
|
||||
|
||||
return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
|
||||
|
||||
def _propose_torchair(
|
||||
self,
|
||||
# [num_tokens]
|
||||
target_token_ids: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_positions: torch.Tensor,
|
||||
# [num_tokens, hidden_size]
|
||||
target_hidden_states: torch.Tensor,
|
||||
# [num_tokens]
|
||||
target_slot_mapping: torch.Tensor,
|
||||
# [batch_size]
|
||||
next_token_ids: torch.Tensor,
|
||||
# [batch_size + 1] starting with 0
|
||||
cu_num_tokens: torch.Tensor,
|
||||
# [batch_size, max_num_blocks_per_req]
|
||||
block_table: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
token_indices=None) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
last_token_indices = cu_num_tokens[1:] - 1
|
||||
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
if token_indices is not None:
|
||||
last_token_indices = token_indices
|
||||
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
|
||||
max_query_len = query_lens.max().item()
|
||||
|
||||
# FIXME: reorder_batch() needs to be called before build()
|
||||
# because fields of attn_metadata_builder needs to be updated.
|
||||
# However, currently reorder_batch() takes input_batch and
|
||||
# scheduler_output as arguments, we should probably refactor
|
||||
# the method to use new data structures which are independent
|
||||
# from input_batch and scheduler_output.
|
||||
# self.runner.attn_metadata_builder.reorder_batch(
|
||||
# input_batch=self.runner.input_batch,
|
||||
# scheduler_output=self.runner.scheduler_output,
|
||||
# )
|
||||
|
||||
if not self.runner.with_prefill:
|
||||
# Torchair graph mode, padding is same as the main model
|
||||
num_input_tokens = self.runner.graph_pad_size
|
||||
elif (self.runner.use_aclgraph
|
||||
and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
|
||||
# Acl graph mode, add padding to the batch size
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
else:
|
||||
# Eager mode, no padding needed
|
||||
num_input_tokens = num_tokens
|
||||
|
||||
seq_lens = target_positions[last_token_indices] + 1
|
||||
seq_lens = seq_lens.int()
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=cu_num_tokens[:batch_size + 1],
|
||||
query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
|
||||
seq_lens_cpu=seq_lens.cpu(),
|
||||
num_reqs=batch_size,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||
get_device_tensor(),
|
||||
slot_mapping=target_slot_mapping,
|
||||
positions=target_positions,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
spec_attn_mask=self.runner.spec_attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
graph_pad_size=self.runner.graph_pad_size,
|
||||
decode_token_per_req=self.runner.decode_token_per_req,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
|
||||
attn_metadata = self.runner.attn_metadata_builder.build(
|
||||
0, common_attn_metadata, self.runner.get_model())
|
||||
|
||||
self.positions[:num_tokens] = target_positions
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
|
||||
# torchair mode can reuse self.runner.num_tokens_across_dp
|
||||
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
||||
with_prefill = self.runner.with_prefill
|
||||
moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens)
|
||||
|
||||
for step in range(self.num_speculative_tokens):
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
with_prefill=with_prefill,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
moe_comm_type=moe_comm_type,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=num_tokens):
|
||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||
model_kwargs = {}
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
|
||||
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
|
||||
if not self.runner.with_prefill:
|
||||
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
num_input_tokens)
|
||||
hidden_states = torchair_compiled_model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.
|
||||
hidden_states[:num_input_tokens],
|
||||
inputs_embeds=None,
|
||||
intermediate_tensors=None,
|
||||
spec_step_idx=0,
|
||||
**model_kwargs)
|
||||
else:
|
||||
hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens]
|
||||
)
|
||||
|
||||
num_indices = last_token_indices.shape[0]
|
||||
if lmhead_tp_enable():
|
||||
if not self.runner.with_prefill:
|
||||
max_num_reqs_across_dp = num_input_tokens
|
||||
else:
|
||||
max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
|
||||
last_token_indices = nn.functional.pad(
|
||||
last_token_indices,
|
||||
(0, max_num_reqs_across_dp - num_indices))
|
||||
|
||||
sample_hidden_states = hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
if lmhead_tp_enable() and num_indices < logits.shape[0]:
|
||||
logits = logits[:num_indices]
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
if self.num_speculative_tokens == 1:
|
||||
# [batch_size, 1]
|
||||
return draft_token_ids.view(-1, 1)
|
||||
|
||||
if step == 0:
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
else:
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
# prepare next mtp inputs
|
||||
# mtp>1: prefill skip or decode skip last loop
|
||||
if with_prefill:
|
||||
for _ in range(self.num_speculative_tokens - 1):
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
if step == self.num_speculative_tokens - 1 or with_prefill:
|
||||
break
|
||||
|
||||
attn_metadata_i = attn_metadata
|
||||
|
||||
if step == 0:
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
slot_mapping = attn_metadata_i.slot_mapping[last_token_indices]
|
||||
attn_metadata_i.slot_mapping.fill_(-1)
|
||||
attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
|
||||
last_token_indices = self.arange[:batch_size]
|
||||
if attn_metadata_i.num_decode_tokens != 0:
|
||||
attn_metadata_i.num_decode_tokens = batch_size
|
||||
if not self.runner.with_prefill:
|
||||
attn_metadata_i.num_actual_tokens = batch_size
|
||||
attn_metadata_i.query_lens = [1] * batch_size
|
||||
|
||||
input_ids = draft_token_ids_list[-1].int()
|
||||
positions += 1
|
||||
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length. Since it is complex
|
||||
# to remove such requests from the batch, we keep them in the batch
|
||||
# but adjust the position ids and slot mappings to avoid the
|
||||
# out-of-range access during the model execution. The draft tokens
|
||||
# generated with this adjustment should be ignored.
|
||||
exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||
positions)
|
||||
# Increment the sequence lengths.
|
||||
attn_metadata_i.seq_lens[:batch_size] += 1
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
|
||||
attn_metadata_i.seq_lens.device, non_blocking=True)
|
||||
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
|
||||
exceeds_max_model_len_cpu, 1)
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
slot_mapping += 1
|
||||
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
self.hidden_states[:hidden_states.shape[0]] = hidden_states
|
||||
attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
|
||||
|
||||
if attn_metadata_i.prefill is not None:
|
||||
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
|
||||
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
|
||||
)
|
||||
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
|
||||
attn_metadata_i.prefill.input_positions = self.positions[:
|
||||
num_input_tokens]
|
||||
attn_metadata_i.prefill.max_seq_lens += 1
|
||||
attn_metadata_i.prefill.max_seq_lens = min(
|
||||
attn_metadata_i.prefill.max_seq_lens,
|
||||
self.runner.model_config.max_model_len)
|
||||
if attn_metadata_i.decode is not None:
|
||||
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
|
||||
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
|
||||
)
|
||||
attn_metadata_i.decode.input_positions = self.positions[:
|
||||
num_input_tokens]
|
||||
attn_metadata_i.decode.max_seq_lens += 1
|
||||
attn_metadata_i.decode.max_seq_lens = min(
|
||||
attn_metadata_i.decode.max_seq_lens,
|
||||
self.runner.model_config.max_model_len)
|
||||
|
||||
# mtp>1: [batch_size, k]
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
|
||||
|
||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
|
||||
-1]:
|
||||
raise ValueError(
|
||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}"
|
||||
)
|
||||
|
||||
compiled_model = self.torchair_compiled_models.get(
|
||||
batch_size
|
||||
) if self.runner.use_cached_npu_graph else self.torchair_compiled_model
|
||||
|
||||
if compiled_model:
|
||||
return compiled_model
|
||||
|
||||
patch_for_hcom()
|
||||
config = torchair.CompilerConfig()
|
||||
config.experimental_config.frozen_parameter = True
|
||||
config.experimental_config.tiling_schedule_optimize = True
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
get_ascend_config().torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
if not self.runner.use_cached_npu_graph:
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
self.torchair_compiled_model = torch.compile(
|
||||
self.model,
|
||||
dynamic=not self.use_sparse,
|
||||
fullgraph=True,
|
||||
backend=npu_backend)
|
||||
return self.torchair_compiled_model
|
||||
else:
|
||||
# Generate a new forward proxy code object to prevent the invalidation of
|
||||
# compilation cache caused by dynamo retracing
|
||||
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
||||
forward_fn = self.model.forward
|
||||
code = forward_fn.__code__
|
||||
# Mark code object with a new proxy name
|
||||
modified_code = code.replace(co_name=forward_proxy_name, )
|
||||
|
||||
modified_func = types.FunctionType(modified_code,
|
||||
forward_fn.__globals__,
|
||||
name=forward_proxy_name,
|
||||
argdefs=forward_fn.__defaults__)
|
||||
|
||||
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
||||
self.model, nn.Module)
|
||||
self.torchair_compiled_models[
|
||||
batch_size] = torchair.inference.cache_compile(
|
||||
self.model.__dict__[forward_proxy_name],
|
||||
dynamic=not self.use_sparse,
|
||||
fullgraph=True,
|
||||
cache_dir=TORCHAIR_CACHE_DIR,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
return self.torchair_compiled_models[batch_size]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,63 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
|
||||
from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist,
|
||||
delete_torchair_cache_file,
|
||||
read_kv_cache_bytes_from_file)
|
||||
from vllm_ascend.worker.worker_v1 import NPUWorker
|
||||
|
||||
|
||||
class NPUTorchairWorker(NPUWorker):
|
||||
"""Torchair worker bases on NPUWorker. Only torchair specified code should be added in this class."""
|
||||
|
||||
def determine_available_memory(self) -> int:
|
||||
"""Override determine_available_memory to use cached torchair kv_cache_bytes."""
|
||||
|
||||
available_kv_cache_memory = super().determine_available_memory()
|
||||
ascend_config = get_ascend_config()
|
||||
if ascend_config.enable_shared_expert_dp:
|
||||
return available_kv_cache_memory
|
||||
if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
|
||||
if check_kv_cache_bytes_cache_exist():
|
||||
old_kv_cache_bytes = read_kv_cache_bytes_from_file(
|
||||
torch.distributed.get_rank())
|
||||
if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
|
||||
logger.info(
|
||||
f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
|
||||
)
|
||||
self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
|
||||
return old_kv_cache_bytes
|
||||
else:
|
||||
logger.info(
|
||||
"Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
|
||||
)
|
||||
delete_torchair_cache_file()
|
||||
bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
|
||||
available_kv_cache_memory -= bytes_floating_tolerance
|
||||
logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
|
||||
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
|
||||
return available_kv_cache_memory
|
||||
|
||||
def init_device(self):
|
||||
"""Override init_device to init torchair model runner"""
|
||||
device = self._init_device()
|
||||
# Init ModelRunner here, so that we have access to self.device.
|
||||
self.model_runner = NPUTorchairModelRunner(self.vllm_config, device)
|
||||
@@ -1,275 +0,0 @@
|
||||
import fcntl
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torchair.scope import super_kernel as _super_kernel
|
||||
|
||||
try:
|
||||
# Recent release of torchair has moved these ops to `.scope`.
|
||||
from torchair.scope import npu_stream_switch as _npu_stream_switch
|
||||
from torchair.scope import npu_wait_tensor as _npu_wait_tensor
|
||||
except ImportError:
|
||||
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
|
||||
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
|
||||
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
|
||||
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
|
||||
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
|
||||
TORCHAIR_CACHE_DIR = os.path.join(
|
||||
os.getenv('TORCHAIR_CACHE_HOME', os.getcwd()), TORCHAIR_CACHE_PATH_NAME)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TorchairCommonAttentionMetadata:
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
For many of the tensors we keep both GPU and CPU versions.
|
||||
"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
|
||||
num_actual_tokens: int
|
||||
"""Total number of tokens in batch"""
|
||||
|
||||
decode_token_per_req: int
|
||||
|
||||
actual_seq_lengths_q: list[int]
|
||||
|
||||
attn_mask: torch.Tensor = None
|
||||
|
||||
spec_attn_mask: torch.Tensor = None
|
||||
|
||||
graph_pad_size: int = -1
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _file_lock(file_descriptor, lock_type):
|
||||
fcntl.flock(file_descriptor, lock_type)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
fcntl.flock(file_descriptor, fcntl.LOCK_UN)
|
||||
|
||||
|
||||
def _get_torchair_current_work_dir(file_name=None):
|
||||
if file_name is None:
|
||||
return TORCHAIR_CACHE_DIR
|
||||
return os.path.join(TORCHAIR_CACHE_DIR, file_name)
|
||||
|
||||
|
||||
def check_torchair_cache_exist():
|
||||
res = False
|
||||
torch_air_abs_path = _get_torchair_current_work_dir()
|
||||
if os.path.exists(torch_air_abs_path):
|
||||
file_list = os.listdir(torch_air_abs_path)
|
||||
if len(file_list) != 0:
|
||||
res = True
|
||||
return res
|
||||
|
||||
|
||||
def check_kv_cache_bytes_cache_exist():
|
||||
res = False
|
||||
kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
|
||||
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
||||
if os.path.exists(kv_cache_bytes_cache_abs_path):
|
||||
file_list = os.listdir(kv_cache_bytes_cache_abs_path)
|
||||
if len(file_list) != 0:
|
||||
res = True
|
||||
return res
|
||||
|
||||
|
||||
def read_kv_cache_bytes_from_file(rank) -> int:
|
||||
kv_cache_bytes = -1
|
||||
kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
|
||||
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
||||
kv_cache_bytes_file = os.path.join(
|
||||
kv_cache_bytes_cache_abs_path,
|
||||
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
|
||||
with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
|
||||
with _file_lock(f, fcntl.LOCK_SH):
|
||||
kv_cache_bytes = int(f.readline())
|
||||
return kv_cache_bytes
|
||||
|
||||
|
||||
def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
|
||||
kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
|
||||
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
||||
os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
|
||||
kv_cache_bytes_file = os.path.join(
|
||||
kv_cache_bytes_cache_abs_path,
|
||||
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
|
||||
with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
|
||||
with _file_lock(f, fcntl.LOCK_EX):
|
||||
f.write(f"{kv_cache_bytes}")
|
||||
|
||||
|
||||
def delete_torchair_cache_file():
|
||||
torch_air_abs_path = _get_torchair_current_work_dir()
|
||||
try:
|
||||
shutil.rmtree(torch_air_abs_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
|
||||
return _npu_stream_switch(tag, priority) if enabled else nullcontext()
|
||||
|
||||
|
||||
def npu_wait_tensor(self: torch.Tensor,
|
||||
dependency: torch.Tensor,
|
||||
*,
|
||||
enabled: bool = True):
|
||||
return _npu_wait_tensor(self, dependency) if enabled else self
|
||||
|
||||
|
||||
def converting_weight_acl_format(model, format):
|
||||
# currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
|
||||
# in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
|
||||
# is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
|
||||
# conversion when using torchair graph mode on 300I Duo platform.
|
||||
# TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
|
||||
# accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, FusedMoE):
|
||||
if torch_npu.get_npu_format(module.w13_weight.data) == format:
|
||||
return
|
||||
if format == ACL_FORMAT_FRACTAL_NZ \
|
||||
and not is_enable_nz():
|
||||
return
|
||||
module.w13_weight.data = torch_npu.npu_format_cast(
|
||||
module.w13_weight.data, format)
|
||||
module.w2_weight.data = torch_npu.npu_format_cast(
|
||||
module.w2_weight.data, format)
|
||||
|
||||
|
||||
def register_torchair_model():
|
||||
from vllm import ModelRegistry
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepSeekMTPModel",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV3ForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV32ForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen2ForCausalLM",
|
||||
"vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3MoeForCausalLM",
|
||||
"vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"PanguProMoEForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
|
||||
)
|
||||
|
||||
|
||||
def torchair_quant_method_register():
|
||||
from vllm_ascend.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP
|
||||
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
|
||||
TorchairAscendW4A8DynamicFusedMoEMethod,
|
||||
TorchairAscendW4A8DynamicLinearMethod)
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
|
||||
TorchairAscendW8A8DynamicFusedMoEMethod,
|
||||
TorchairAscendW8A8DynamicLinearMethod)
|
||||
|
||||
ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][
|
||||
"linear"] = TorchairAscendW8A8DynamicLinearMethod
|
||||
ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][
|
||||
"moe"] = TorchairAscendW8A8DynamicFusedMoEMethod
|
||||
ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][
|
||||
"linear"] = TorchairAscendW4A8DynamicLinearMethod
|
||||
ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][
|
||||
"moe"] = TorchairAscendW4A8DynamicFusedMoEMethod
|
||||
|
||||
|
||||
def torchair_ops_patch():
|
||||
from vllm_ascend.ops.activation import AscendSiluAndMul
|
||||
from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
|
||||
from vllm_ascend.ops.rotary_embedding import (
|
||||
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import \
|
||||
AscendVocabParallelEmbedding
|
||||
from vllm_ascend.torchair.ops import (torchair_activation,
|
||||
torchair_layernorm)
|
||||
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
|
||||
deepseek_rope_init_func, native_rope_deepseek_forward,
|
||||
qwen_rope_init_func, rope_forward)
|
||||
from vllm_ascend.torchair.ops.torchair_vocab_parallel_embedding import \
|
||||
vocab_embedding_forward
|
||||
|
||||
AscendRotaryEmbedding.__init__ = qwen_rope_init_func # type: ignore[method-assign]
|
||||
AscendRotaryEmbedding.forward_oot = rope_forward # type: ignore[method-assign]
|
||||
|
||||
AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func # type: ignore[method-assign]
|
||||
AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward # type: ignore[method-assign]
|
||||
|
||||
AscendRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_ # type: ignore[method-assign]
|
||||
AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign]
|
||||
|
||||
AscendQuantRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_ # type: ignore[method-assign]
|
||||
AscendQuantRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign]
|
||||
|
||||
AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign]
|
||||
AscendVocabParallelEmbedding.forward = vocab_embedding_forward # type: ignore[method-assign]
|
||||
|
||||
|
||||
def super_kernel(prefix: str, option: str, enabled: bool = True):
|
||||
return _super_kernel(prefix, option) if enabled else nullcontext()
|
||||
|
||||
|
||||
# TODO(ttanzhiqiang): rm_router_logits
|
||||
# dp>1 will trigger
|
||||
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
|
||||
def get_rm_router_logits_state(ep_size: int, dp_size: int,
|
||||
is_deepseek_v3_r1: bool):
|
||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||
# only supports deepseek v3/r1
|
||||
if dp_size > 1:
|
||||
if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||
and is_deepseek_v3_r1):
|
||||
return True
|
||||
elif ep_size == 1 and is_deepseek_v3_r1:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# TODO(ttanzhiqiang): all_reduce merge
|
||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
|
||||
def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
|
||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||
# only supports deepseek v3/r1
|
||||
if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||
and is_deepseek_v3_r1):
|
||||
return True
|
||||
elif ep_size == 1 and is_deepseek_v3_r1:
|
||||
return True
|
||||
return False
|
||||
@@ -143,7 +143,6 @@ from vllm_ascend.spec_decode import get_spec_decode_method
|
||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||
from vllm_ascend.spec_decode.interface import SpecDcodeType
|
||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
AscendDeviceType, ProfileExecuteDuration,
|
||||
enable_sp, get_ascend_device_type, is_enable_nz,
|
||||
@@ -638,7 +637,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
# Set up speculative decoding.
|
||||
self.spec_attn_mask = None
|
||||
self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
|
||||
TorchairMtpProposer,
|
||||
SuffixDecodingProposer]] = None
|
||||
self.actual_seq_lengths_q: list[int] = []
|
||||
self.decode_token_per_req = 1
|
||||
@@ -2917,8 +2915,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def _generate_dummy_run_hidden_states(self, with_prefill,
|
||||
is_torchair_compile, input_ids,
|
||||
def _generate_dummy_run_hidden_states(self, with_prefill, input_ids,
|
||||
positions, attn_metadata, num_tokens,
|
||||
intermediate_tensors, inputs_embeds):
|
||||
hidden_states = self.model(input_ids=input_ids,
|
||||
@@ -2960,7 +2957,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
self,
|
||||
num_tokens: int,
|
||||
with_prefill: bool = False,
|
||||
is_torchair_compile: bool = False,
|
||||
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
||||
force_attention: bool = False,
|
||||
uniform_decode: bool = False,
|
||||
@@ -3136,9 +3132,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
model_instance=self.model,
|
||||
weight_prefetch_method=self.weight_prefetch_method):
|
||||
hidden_states = self._generate_dummy_run_hidden_states(
|
||||
with_prefill, is_torchair_compile, input_ids, positions,
|
||||
attn_metadata, num_tokens_padded, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
with_prefill, input_ids, positions, attn_metadata,
|
||||
num_tokens_padded, intermediate_tensors, inputs_embeds)
|
||||
dummy_compute_logits(hidden_states)
|
||||
|
||||
if self.drafter:
|
||||
@@ -4262,9 +4257,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
|
||||
return list(model.pooler.get_supported_tasks())
|
||||
|
||||
def _build_drafter_prepare_inputs_torchair_param(self):
|
||||
return False
|
||||
|
||||
def _update_tokens_for_pcp(self, tokens):
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
|
||||
|
||||
Reference in New Issue
Block a user